1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16 
17 #include <vector>
18 
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/numeric_op.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/blocking_counter.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/monitoring/counter.h"
33 #include "tensorflow/core/platform/byte_order.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/util/presized_cuckoo_map.h"
37 #include "tensorflow/core/util/sparse/sparse_tensor.h"
38 
39 namespace tensorflow {
40 namespace example {
41 
42 namespace {
43 
// Small-buffer vector used for per-feature value lists: up to 4 elements are
// stored inline before the container falls back to a heap allocation.
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
46 
47 template <typename T>
48 class LimitedArraySlice {
49  public:
50   using value_type = T;
51 
LimitedArraySlice(T * begin,size_t num_elements)52   LimitedArraySlice(T* begin, size_t num_elements)
53       : current_(begin), begin_(begin), end_(begin + num_elements) {}
54 
55   // May return negative if there were push_back calls after slice was filled.
EndDistance() const56   int64 EndDistance() const { return end_ - current_; }
57 
58   // Attempts to push value to the back of this. If the slice has
59   // already been filled, this method has no effect on the underlying data, but
60   // it changes the number returned by EndDistance into negative values.
push_back(T && value)61   void push_back(T&& value) {
62     if (EndDistance() > 0) *current_ = std::move(value);
63     ++current_;
64   }
65 
66   // "Constructs" an element at the back of this by resizing the slice, and
67   // returns a mutable reference to the new last element.
68   // REQUIRES: EndDistance() > 0.
construct_at_end()69   T& construct_at_end() {
70     DCHECK_GT(EndDistance(), 0);
71     return *(current_++);
72   }
73 
74   // Returns a mutable reference to the last element in the slice.
75   // REQUIRES: size() > 0.
back()76   T& back() { return *(current_ - 1); }
77 
78   // Returns the number of elements in the slice.
size() const79   size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
80 
81   // Attempts to resize the vector to the given size. It does so by advancing
82   // the pointer to the current element, possibly beyond the end of the slice.
83   // As a consequence, calling `size()` after `resize(x)` was called might
84   // return a value less than `x`.
resize(size_t size)85   void resize(size_t size) { current_ = begin_ + size; }
86 
87   // Returns the pointer to the underlying data buffer.
data()88   T* data() { return begin_; }
89 
90  private:
91   T* current_;
92   T* begin_;
93   T* end_;
94 };
95 
// Calls `a->EnableAliasing(true)` on the pointee. This overload only takes
// part in overload resolution when that expression is well-formed, because the
// expression appears in the trailing return type (expression SFINAE).
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Fallback no-op chosen when the overload above is not viable, so callers can
// invoke EnableAliasing unconditionally. NOTE(review): presumably this exists
// for protobuf builds whose CodedInputStream lacks EnableAliasing — confirm.
template <typename A>
void EnableAliasing(A&& a) {}
103 
PeekTag(protobuf::io::CodedInputStream * stream)104 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
105   DCHECK(stream != nullptr);
106   const void* ptr;
107   int size;
108   if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
109   return *static_cast<const uint8*>(ptr);
110 }
111 
kVarintTag(uint32 tag)112 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)113 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)114 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
115 
116 namespace parsed {
117 
118 // ParseDataType has to be called first, then appropriate ParseZzzzList.
119 class Feature {
120  public:
Feature()121   Feature() {}
Feature(StringPiece serialized)122   explicit Feature(StringPiece serialized) : serialized_(serialized) {}
123 
ParseDataType(DataType * dtype)124   Status ParseDataType(DataType* dtype) {
125     DCHECK(dtype != nullptr);
126     if (serialized_.empty()) {
127       *dtype = DT_INVALID;
128       return Status::OK();
129     }
130     uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
131     serialized_.remove_prefix(1);
132     switch (oneof_tag) {
133       case kDelimitedTag(1):
134         *dtype = DT_STRING;
135         break;
136       case kDelimitedTag(2):
137         *dtype = DT_FLOAT;
138         break;
139       case kDelimitedTag(3):
140         *dtype = DT_INT64;
141         break;
142       default:
143         // Initialize variable to avoid compiler warning
144         *dtype = DT_INVALID;
145         return errors::InvalidArgument("Unsupported datatype.");
146     }
147     return Status::OK();
148   }
149 
GetNumElementsInBytesList(int * num_elements)150   bool GetNumElementsInBytesList(int* num_elements) {
151     protobuf::io::CodedInputStream stream(
152         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
153     EnableAliasing(&stream);
154     uint32 length = 0;
155     if (!stream.ReadVarint32(&length)) return false;
156     auto limit = stream.PushLimit(length);
157     *num_elements = 0;
158     while (!stream.ExpectAtEnd()) {
159       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
160       uint32 bytes_length = 0;
161       if (!stream.ReadVarint32(&bytes_length)) return false;
162       if (!stream.Skip(bytes_length)) return false;
163       ++*num_elements;
164     }
165     stream.PopLimit(limit);
166     return true;
167   }
168 
169   // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)170   tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
171     if (bytes_list->EndDistance() <= 0) {
172       return nullptr;
173     }
174     return &bytes_list->construct_at_end();
175   }
construct_at_end(SmallVector<tstring> * bytes_list)176   tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
177     return &bytes_list->emplace_back();
178   }
179 
180   template <typename Result>
ParseBytesList(Result * bytes_list)181   bool ParseBytesList(Result* bytes_list) {
182     DCHECK(bytes_list != nullptr);
183 
184     protobuf::io::CodedInputStream stream(
185         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
186 
187     EnableAliasing(&stream);
188 
189     uint32 length;
190     if (!stream.ReadVarint32(&length)) return false;
191     auto limit = stream.PushLimit(length);
192 
193     while (!stream.ExpectAtEnd()) {
194       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
195       // parse string
196       uint32 bytes_length;
197       if (!stream.ReadVarint32(&bytes_length)) return false;
198       tstring* bytes = construct_at_end(bytes_list);
199       if (bytes == nullptr) return false;
200       bytes->resize_uninitialized(bytes_length);
201       if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
202     }
203     stream.PopLimit(limit);
204     return true;
205   }
206 
207   template <typename Result>
ParseFloatList(Result * float_list)208   bool ParseFloatList(Result* float_list) {
209     DCHECK(float_list != nullptr);
210     protobuf::io::CodedInputStream stream(
211         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
212     EnableAliasing(&stream);
213     uint32 length;
214     if (!stream.ReadVarint32(&length)) return false;
215     auto limit = stream.PushLimit(length);
216 
217     if (!stream.ExpectAtEnd()) {
218       uint8 peek_tag = PeekTag(&stream);
219       if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
220         return false;
221       }
222 
223       constexpr int32 kNumFloatBytes = 4;
224       if (peek_tag == kDelimitedTag(1)) {                       // packed
225         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
226         uint32 packed_length;
227         if (!stream.ReadVarint32(&packed_length)) return false;
228         auto packed_limit = stream.PushLimit(packed_length);
229 
230         // Store the initial size to know the offset we have to start writing
231         // data from before resizing the output "vector".
232         const size_t initial_size = float_list->size();
233         float_list->resize(initial_size + packed_length / kNumFloatBytes);
234 
235         // If the result data type is float and we are on a little endian
236         // machine then we can simply memcpy the data from the proto into the
237         // result vector.
238         if (port::kLittleEndian &&
239             sizeof(typename Result::value_type) == kNumFloatBytes) {
240           // Calculate the length of the buffer available what can be less than
241           // what we requested in resize in case of a LimitedArraySlice.
242           const uint32 bytes_to_copy =
243               std::min(static_cast<uint32>((float_list->size() - initial_size) *
244                                            kNumFloatBytes),
245                        packed_length);
246           if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
247             return false;
248         } else {
249           int64 index = initial_size;
250           while (!stream.ExpectAtEnd()) {
251             uint32 buffer32;
252             if (!stream.ReadLittleEndian32(&buffer32)) return false;
253             if (index < float_list->size()) {
254               float_list->data()[index] = absl::bit_cast<float>(buffer32);
255               ++index;
256             }
257           }
258         }
259 
260         stream.PopLimit(packed_limit);
261       } else {  // non-packed
262         const size_t initial_size = float_list->size();
263         // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
264         // the value.
265         const int64 num_elements =
266             stream.BytesUntilLimit() / (1 + kNumFloatBytes);
267         float_list->resize(initial_size + num_elements);
268         int64 index = initial_size;
269         while (!stream.ExpectAtEnd()) {
270           if (!stream.ExpectTag(kFixed32Tag(1))) return false;
271           uint32 buffer32;
272           if (!stream.ReadLittleEndian32(&buffer32)) return false;
273           float_list->data()[index] = absl::bit_cast<float>(buffer32);
274           ++index;
275         }
276       }
277     }
278 
279     stream.PopLimit(limit);
280     return true;
281   }
282 
283   template <typename Result>
ParseInt64List(Result * int64_list)284   bool ParseInt64List(Result* int64_list) {
285     DCHECK(int64_list != nullptr);
286     protobuf::io::CodedInputStream stream(
287         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
288     EnableAliasing(&stream);
289     uint32 length;
290     if (!stream.ReadVarint32(&length)) return false;
291     auto limit = stream.PushLimit(length);
292 
293     if (!stream.ExpectAtEnd()) {
294       uint8 peek_tag = PeekTag(&stream);
295       if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
296         return false;
297       }
298       if (peek_tag == kDelimitedTag(1)) {                       // packed
299         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
300         uint32 packed_length;
301         if (!stream.ReadVarint32(&packed_length)) return false;
302         auto packed_limit = stream.PushLimit(packed_length);
303 
304         while (!stream.ExpectAtEnd()) {
305           protobuf_uint64 n;  // There is no API for int64
306           if (!stream.ReadVarint64(&n)) return false;
307           int64_list->push_back(static_cast<int64>(n));
308         }
309 
310         stream.PopLimit(packed_limit);
311       } else {  // non-packed
312         while (!stream.ExpectAtEnd()) {
313           if (!stream.ExpectTag(kVarintTag(1))) return false;
314           protobuf_uint64 n;  // There is no API for int64
315           if (!stream.ReadVarint64(&n)) return false;
316           int64_list->push_back(static_cast<int64>(n));
317         }
318       }
319     }
320     stream.PopLimit(limit);
321     return true;
322   }
323 
GetSerialized() const324   StringPiece GetSerialized() const { return serialized_; }
325 
326  private:
327   // TODO(lew): Pair of uint8* would be more natural.
328   StringPiece serialized_;
329 };
330 
// One Features.feature map entry: the feature name and its lazily parsed
// Feature, both aliasing the serialized input.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
// All map entries of an Example in serialized order. Duplicate keys are kept;
// consumers walk the vector back to front so the last occurrence wins.
using Example = std::vector<FeatureMapEntry>;
333 
334 }  // namespace parsed
335 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)336 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
337   uint32 data;
338   protobuf_uint64 dummy;
339   switch (stream->ReadTag() & 0x7) {
340     case 0:  // varint
341       if (!stream->ReadVarint32(&data)) return false;
342       return true;
343     case 1:  // fixed64
344       if (!stream->ReadLittleEndian64(&dummy)) return false;
345       return true;
346     case 2:  // length delimited
347       if (!stream->ReadVarint32(&data)) return false;
348       stream->Skip(data);
349       return true;
350     case 3:          // group begin
351       return false;  // groups not supported.
352     case 4:          // group end
353       return false;  // groups not supported.
354     case 5:          // fixed32
355       if (!stream->ReadLittleEndian32(&data)) return false;
356       return true;
357   }
358   return false;  // unrecognized tag type
359 }
360 
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)361 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
362   DCHECK(stream != nullptr);
363   DCHECK(result != nullptr);
364   uint32 length;
365   if (!stream->ReadVarint32(&length)) return false;
366   if (length == 0) {
367     *result = StringPiece(nullptr, 0);
368     return true;
369   }
370   const void* stream_alias;
371   int stream_size;
372   if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
373     return false;
374   }
375   if (static_cast<uint32>(stream_size) < length) return false;
376   *result = StringPiece(static_cast<const char*>(stream_alias), length);
377   stream->Skip(length);
378   return true;
379 }
380 
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)381 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
382                           parsed::FeatureMapEntry* feature_map_entry) {
383   DCHECK(stream != nullptr);
384   DCHECK(feature_map_entry != nullptr);
385   uint32 length;
386   if (!stream->ReadVarint32(&length)) return false;
387   auto limit = stream->PushLimit(length);
388   if (!stream->ExpectTag(kDelimitedTag(1))) return false;
389   if (!ParseString(stream, &feature_map_entry->first)) return false;
390   if (!stream->ExpectTag(kDelimitedTag(2))) return false;
391   StringPiece feature_string_piece;
392   if (!ParseString(stream, &feature_string_piece)) return false;
393   feature_map_entry->second = parsed::Feature(feature_string_piece);
394   if (!stream->ExpectAtEnd()) return false;
395   stream->PopLimit(limit);
396   return true;
397 }
398 
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)399 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
400                    parsed::Example* example) {
401   DCHECK(stream != nullptr);
402   DCHECK(example != nullptr);
403   uint32 length;
404   if (!stream->ReadVarint32(&length)) return false;
405   auto limit = stream->PushLimit(length);
406   while (!stream->ExpectAtEnd()) {
407     parsed::FeatureMapEntry feature_map_entry;
408     if (!stream->ExpectTag(kDelimitedTag(1))) return false;
409     if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
410     example->push_back(std::move(feature_map_entry));
411   }
412   stream->PopLimit(limit);
413   return true;
414 }
415 
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)416 bool ParseExample(protobuf::io::CodedInputStream* stream,
417                   parsed::Example* example) {
418   DCHECK(stream != nullptr);
419   DCHECK(example != nullptr);
420   // Loop over the input stream which may contain multiple serialized Example
421   // protos merged together as strings. This behavior is consistent with Proto's
422   // ParseFromString when string representations are concatenated.
423   while (!stream->ExpectAtEnd()) {
424     if (!stream->ExpectTag(kDelimitedTag(1))) {
425       if (!SkipExtraneousTag(stream)) return false;
426     } else {
427       if (!ParseFeatures(stream, example)) return false;
428     }
429   }
430   return true;
431 }
432 
ParseExample(StringPiece serialized,parsed::Example * example)433 bool ParseExample(StringPiece serialized, parsed::Example* example) {
434   DCHECK(example != nullptr);
435   protobuf::io::CodedInputStream stream(
436       reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
437   EnableAliasing(&stream);
438   return ParseExample(&stream, example);
439 }
440 
441 }  // namespace
442 
TestFastParse(const string & serialized,Example * example)443 bool TestFastParse(const string& serialized, Example* example) {
444   DCHECK(example != nullptr);
445   parsed::Example parsed_example;
446   if (!ParseExample(serialized, &parsed_example)) return false;
447   auto& features = *example->mutable_features();
448   size_t parsed_example_size = parsed_example.size();
449   for (size_t i = 0; i < parsed_example_size; ++i) {
450     // This is a logic that standard protobuf parsing is implementing.
451     // I.e. last entry in the map overwrites all the previous ones.
452     parsed::FeatureMapEntry& name_and_feature =
453         parsed_example[parsed_example_size - i - 1];
454     string name(name_and_feature.first);
455     if ((*features.mutable_feature()).count(name) > 0) continue;
456 
457     auto& value = (*features.mutable_feature())[name];
458     DataType dtype;
459     if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
460     switch (dtype) {
461       case DT_INVALID:
462         break;
463       case DT_STRING: {
464         SmallVector<tstring> list;
465         if (!name_and_feature.second.ParseBytesList(&list)) return false;
466         auto* result_list = value.mutable_bytes_list();
467         for (auto& bytes : list) {
468           result_list->add_value(bytes.data(), bytes.size());
469         }
470         break;
471       }
472       case DT_FLOAT: {
473         SmallVector<float> list;
474         if (!name_and_feature.second.ParseFloatList(&list)) return false;
475         auto* result_list = value.mutable_float_list();
476         for (float f : list) {
477           result_list->add_value(f);
478         }
479         break;
480       }
481       case DT_INT64: {
482         SmallVector<int64> list;
483         if (!name_and_feature.second.ParseInt64List(&list)) return false;
484         auto* result_list = value.mutable_int64_list();
485         for (int64 i : list) {
486           result_list->add_value(i);
487         }
488         break;
489       }
490       default:
491         LOG(FATAL) << "Should not happen.";
492     }
493   }
494   return true;
495 }
496 
497 // -----------------------------------------------------------------------------
498 
499 namespace {
500 
// Short alias for the fast-parsing configuration type used in this section.
using Config = FastParseExampleConfig;
502 
ParallelFor(const std::function<void (size_t)> & f,size_t n,thread::ThreadPool * thread_pool)503 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
504                  thread::ThreadPool* thread_pool) {
505   if (n == 0) return;
506   if (thread_pool == nullptr) {
507     for (size_t i = 0; i < n; ++i) {
508       f(i);
509     }
510   } else {
511     BlockingCounter counter(n - 1);
512     for (size_t i = 1; i < n; ++i) {
513       thread_pool->Schedule([i, &f, &counter] {
514         f(i);
515         counter.DecrementCount();
516       });
517     }
518     f(0);
519     counter.Wait();
520   }
521 }
522 
// Enumeration for distinguishing feature types (dense, sparse or ragged).
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense, so Dense
// must remain the first enumerator (value 0).
enum class Type { Dense, Sparse, Ragged };
527 
// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
// Accumulates the values of one feature across examples, together with the
// per-example end offsets into the value list.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list (starting at index 0 for the first example). Each
  // entry is the size of that list right after the example was appended.
  std::vector<size_t> example_end_indices;
};
541 
542 struct SeededHasher {
operator ()tensorflow::example::__anon4c07364d0211::SeededHasher543   uint64 operator()(StringPiece s) const {
544     return Hash64(s.data(), s.size(), seed);
545   }
546   uint64 seed{0xDECAFCAFFE};
547 };
548 
LogDenseFeatureDataLoss(StringPiece feature_name)549 void LogDenseFeatureDataLoss(StringPiece feature_name) {
550   LOG(WARNING) << "Data loss! Feature '" << feature_name
551                << "' is present in multiple concatenated "
552                   "tf.Examples. Ignoring all but last one.";
553   static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
554       "/tensorflow/core/util/example_proto_fast_parsing/"
555       "duplicated_dense_feature",
556       "Dense feature appears twice in a tf.Example");
557   duplicated_dense_feature->GetCell()->IncrementBy(1);
558 }
559 
LogSparseFeatureDataLoss(StringPiece feature_name)560 void LogSparseFeatureDataLoss(StringPiece feature_name) {
561   LOG(WARNING) << "Data loss! Feature '" << feature_name
562                << "' is present in multiple concatenated "
563                   "tf.Examples. Ignoring all but last one.";
564   static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
565       "/tensorflow/core/util/example_proto_fast_parsing/"
566       "duplicated_sparse_feature",
567       "Sparse feature appears twice in a tf.Example");
568   duplicated_sparse_feature->GetCell()->IncrementBy(1);
569 }
570 
FastParseSerializedExample(const tstring & serialized_example,const tstring & example_name,const size_t example_index,const Config & config,const PresizedCuckooMap<std::pair<size_t,Type>> & config_index,SeededHasher hasher,std::vector<Tensor> * output_dense,std::vector<SparseBuffer> * output_varlen_dense,std::vector<SparseBuffer> * output_sparse,std::vector<SparseBuffer> * output_ragged,PerExampleFeatureStats * output_stats)571 Status FastParseSerializedExample(
572     const tstring& serialized_example, const tstring& example_name,
573     const size_t example_index, const Config& config,
574     const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
575     SeededHasher hasher, std::vector<Tensor>* output_dense,
576     std::vector<SparseBuffer>* output_varlen_dense,
577     std::vector<SparseBuffer>* output_sparse,
578     std::vector<SparseBuffer>* output_ragged,
579     PerExampleFeatureStats* output_stats) {
580   DCHECK(output_dense != nullptr);
581   DCHECK(output_sparse != nullptr);
582   DCHECK(output_ragged != nullptr);
583   parsed::Example parsed_example;
584   if (!ParseExample(serialized_example, &parsed_example)) {
585     return errors::InvalidArgument("Could not parse example input, value: '",
586                                    serialized_example, "'");
587   }
588   std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
589   std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
590   std::vector<int64> ragged_feature_last_example(config.ragged.size(), -1);
591 
592   // Handle features present in the example.
593   const size_t parsed_example_size = parsed_example.size();
594 
595   if (output_stats) {
596     // TODO(b/111553342): This may over-count the number of features if there
597     // are duplicate keys in the feature map. Consider deduplicating the keys
598     // before computing the count.
599     output_stats->features_count = parsed_example_size;
600   }
601 
602   for (size_t i = 0; i < parsed_example_size; ++i) {
603     // This is a logic that standard protobuf parsing is implementing.
604     // I.e. last entry in the map overwrites all the previous ones.
605     parsed::FeatureMapEntry& name_and_feature =
606         parsed_example[parsed_example_size - i - 1];
607 
608     const StringPiece feature_name = name_and_feature.first;
609     parsed::Feature& feature = name_and_feature.second;
610 
611     std::pair<size_t, Type> d_and_type;
612     uint64 h = hasher(feature_name);
613     if (!config_index.Find(h, &d_and_type)) continue;
614 
615     size_t d = d_and_type.first;
616     bool is_dense = d_and_type.second == Type::Dense;
617     bool is_ragged = d_and_type.second == Type::Ragged;
618 
619     {
620       // Testing for PresizedCuckooMap collision.
621       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
622       const tstring& config_feature_name =
623           is_dense ? config.dense[d].feature_name
624                    : (is_ragged ? config.ragged[d].feature_name
625                                 : config.sparse[d].feature_name);
626       if (feature_name != config_feature_name) continue;
627     }
628 
629     auto example_error = [&](StringPiece suffix) {
630       return errors::InvalidArgument("Name: ", example_name,
631                                      ", Key: ", feature_name,
632                                      ", Index: ", example_index, ".  ", suffix);
633     };
634 
635     auto parse_error = [&] {
636       return example_error("Can't parse serialized Example.");
637     };
638 
639     DataType example_dtype;
640     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
641 
642     if (is_dense) {
643       if (example_dtype == DT_INVALID) continue;
644 
645       // If feature was already visited, skip.
646       // Compare comment at the beginning of the loop.
647       if (dense_feature_last_example[d] == example_index) {
648         LogDenseFeatureDataLoss(feature_name);
649         continue;
650       }
651       dense_feature_last_example[d] = example_index;
652 
653       if (example_dtype != config.dense[d].dtype) {
654         return example_error(strings::StrCat(
655             "Data types don't match. Data type: ",
656             DataTypeString(example_dtype),
657             " but expected type: ", DataTypeString(config.dense[d].dtype)));
658       }
659       if (!config.dense[d].variable_length) {
660         Tensor& out = (*output_dense)[d];
661 
662         const std::size_t num_elements = config.dense[d].elements_per_stride;
663         if (output_stats) {
664           // TODO(b/111553342): If desirable, we could add support for counting
665           // elements in the features that aren't parsed, but this could add
666           // considerable runtime cost.
667           output_stats->feature_values_count += num_elements;
668         }
669 
670         const std::size_t offset = example_index * num_elements;
671 
672         auto shape_error = [&](size_t size, StringPiece type_str) {
673           return example_error(strings::StrCat(
674               "Number of ", type_str,
675               " values != expected.  "
676               "Values size: ",
677               size,
678               " but output shape: ", config.dense[d].shape.DebugString()));
679         };
680 
681         switch (config.dense[d].dtype) {
682           case DT_INT64: {
683             auto out_p = out.flat<int64>().data() + offset;
684             LimitedArraySlice<int64> slice(out_p, num_elements);
685             if (!feature.ParseInt64List(&slice)) return parse_error();
686             if (slice.EndDistance() != 0) {
687               return shape_error(num_elements - slice.EndDistance(), "int64");
688             }
689             break;
690           }
691           case DT_FLOAT: {
692             auto out_p = out.flat<float>().data() + offset;
693             LimitedArraySlice<float> slice(out_p, num_elements);
694             if (!feature.ParseFloatList(&slice)) return parse_error();
695             if (slice.EndDistance() != 0) {
696               return shape_error(num_elements - slice.EndDistance(), "float");
697             }
698             break;
699           }
700           case DT_STRING: {
701             auto out_p = out.flat<tstring>().data() + offset;
702             LimitedArraySlice<tstring> slice(out_p, num_elements);
703             if (!feature.ParseBytesList(&slice)) return parse_error();
704             if (slice.EndDistance() != 0) {
705               return shape_error(num_elements - slice.EndDistance(), "bytes");
706             }
707             break;
708           }
709           default:
710             LOG(FATAL) << "Should not happen.";
711         }
712       } else {  // if variable length
713         SparseBuffer& out = (*output_varlen_dense)[d];
714 
715         const std::size_t num_elements = config.dense[d].elements_per_stride;
716 
717         if (example_dtype != DT_INVALID &&
718             example_dtype != config.dense[d].dtype) {
719           return example_error(strings::StrCat(
720               "Data types don't match. ",
721               "Expected type: ", DataTypeString(config.dense[d].dtype)));
722         }
723 
724         auto shape_error = [&](size_t size, StringPiece type_str) {
725           return example_error(strings::StrCat(
726               "Number of ", type_str,
727               " values is not a multiple of stride length. Saw ", size,
728               " values but output shape is: ",
729               config.dense[d].shape.DebugString()));
730         };
731 
732         switch (config.dense[d].dtype) {
733           case DT_INT64: {
734             if (example_dtype != DT_INVALID) {
735               if (!feature.ParseInt64List(&out.int64_list)) {
736                 return parse_error();
737               }
738               if (out.int64_list.size() % num_elements != 0) {
739                 return shape_error(out.int64_list.size(), "int64");
740               }
741             }
742             out.example_end_indices.push_back(out.int64_list.size());
743             break;
744           }
745           case DT_FLOAT: {
746             if (example_dtype != DT_INVALID) {
747               if (!feature.ParseFloatList(&out.float_list)) {
748                 return parse_error();
749               }
750               if (out.float_list.size() % num_elements != 0) {
751                 return shape_error(out.float_list.size(), "float");
752               }
753             }
754             out.example_end_indices.push_back(out.float_list.size());
755             break;
756           }
757           case DT_STRING: {
758             if (example_dtype != DT_INVALID) {
759               if (!feature.ParseBytesList(&out.bytes_list)) {
760                 return parse_error();
761               }
762               if (out.bytes_list.size() % num_elements != 0) {
763                 return shape_error(out.bytes_list.size(), "bytes");
764               }
765             }
766             out.example_end_indices.push_back(out.bytes_list.size());
767             break;
768           }
769           default:
770             LOG(FATAL) << "Should not happen.";
771         }
772 
773         if (output_stats) {
774           // Use `out.example_end_indices` to determine the feature-value count
775           // for this feature, because the preceding switch statement pushes
776           // the length of the appropriate feature list to that vector.
777           // TODO(b/111553342): If desirable, we could add support for counting
778           // elements in the features that aren't parsed, but this could add
779           // considerable runtime cost.
780           const size_t out_examples_count = out.example_end_indices.size();
781           if (out_examples_count == 1) {
782             output_stats->feature_values_count += out.example_end_indices[0];
783           } else {
784             output_stats->feature_values_count +=
785                 out.example_end_indices[out_examples_count - 1] -
786                 out.example_end_indices[out_examples_count - 2];
787           }
788         }
789       }
790     } else {
791       // Feature is sparse or ragged.
792       auto& last_example =
793           is_ragged ? ragged_feature_last_example : sparse_feature_last_example;
794 
795       // If feature was already visited, skip.
796       // Compare comment at the beginning of the loop.
797       if (last_example[d] == example_index) {
798         LogSparseFeatureDataLoss(feature_name);
799         continue;
800       }
801       last_example[d] = example_index;
802 
803       // Handle sparse features.
804       SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
805       DataType feature_dtype =
806           is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
807       if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
808         return example_error(
809             strings::StrCat("Data types don't match. ",
810                             "Expected type: ", DataTypeString(feature_dtype),
811                             ", Actual type: ", DataTypeString(example_dtype)));
812       }
813 
814       switch (feature_dtype) {
815         case DT_INT64: {
816           if (example_dtype != DT_INVALID) {
817             if (!feature.ParseInt64List(&out.int64_list)) {
818               return parse_error();
819             }
820           }
821           out.example_end_indices.push_back(out.int64_list.size());
822           break;
823         }
824         case DT_FLOAT: {
825           if (example_dtype != DT_INVALID) {
826             if (!feature.ParseFloatList(&out.float_list)) {
827               return parse_error();
828             }
829           }
830           out.example_end_indices.push_back(out.float_list.size());
831           break;
832         }
833         case DT_STRING: {
834           if (example_dtype != DT_INVALID) {
835             if (!feature.ParseBytesList(&out.bytes_list)) {
836               return parse_error();
837             }
838           }
839           out.example_end_indices.push_back(out.bytes_list.size());
840           break;
841         }
842         default:
843           LOG(FATAL) << "Should not happen.";
844       }
845 
846       if (output_stats) {
847         // Use `out.example_end_indices` to determine the feature-value count
848         // for this feature, because the preceding switch statement pushes
849         // the length of the appropriate feature list to that vector.
850         // TODO(b/111553342): If desirable, we could add support for counting
851         // elements in the features that aren't parsed, but this could add
852         // considerable runtime cost.
853         const size_t out_examples_count = out.example_end_indices.size();
854         if (out_examples_count == 1) {
855           output_stats->feature_values_count += out.example_end_indices[0];
856         } else {
857           output_stats->feature_values_count +=
858               out.example_end_indices[out_examples_count - 1] -
859               out.example_end_indices[out_examples_count - 2];
860         }
861       }
862     }
863   }
864 
865   // Handle missing dense features for fixed strides.
866   for (size_t d = 0; d < config.dense.size(); ++d) {
867     if (config.dense[d].variable_length) continue;
868     if (dense_feature_last_example[d] == example_index) continue;
869     if (config.dense[d].default_value.NumElements() == 0) {
870       return errors::InvalidArgument(
871           "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
872           " (data type: ", DataTypeString(config.dense[d].dtype), ")",
873           " is required but could not be found.");
874     }
875     const Tensor& in = config.dense[d].default_value;
876     Tensor& out = (*output_dense)[d];
877     const std::size_t num_elements = in.shape().num_elements();
878     const std::size_t offset = example_index * num_elements;
879 
880     switch (config.dense[d].dtype) {
881       case DT_INT64: {
882         std::copy_n(in.flat<int64>().data(), num_elements,
883                     out.flat<int64>().data() + offset);
884         break;
885       }
886       case DT_FLOAT: {
887         std::copy_n(in.flat<float>().data(), num_elements,
888                     out.flat<float>().data() + offset);
889         break;
890       }
891       case DT_STRING: {
892         std::copy_n(in.flat<tstring>().data(), num_elements,
893                     out.flat<tstring>().data() + offset);
894         break;
895       }
896       default:
897         LOG(FATAL) << "Should not happen.";
898     }
899   }
900 
901   // Handle missing varlen dense features.
902   for (size_t d = 0; d < config.dense.size(); ++d) {
903     if (!config.dense[d].variable_length) continue;
904     if (dense_feature_last_example[d] == example_index) continue;
905     SparseBuffer& out = (*output_varlen_dense)[d];
906     size_t prev_example_end_index =
907         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
908     out.example_end_indices.push_back(prev_example_end_index);
909   }
910 
911   // Handle missing sparse features.
912   for (size_t d = 0; d < config.sparse.size(); ++d) {
913     if (sparse_feature_last_example[d] == example_index) continue;
914     SparseBuffer& out = (*output_sparse)[d];
915     size_t prev_example_end_index =
916         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
917     out.example_end_indices.push_back(prev_example_end_index);
918   }
919 
920   // Handle missing ragged features.
921   for (size_t d = 0; d < config.ragged.size(); ++d) {
922     if (ragged_feature_last_example[d] == example_index) continue;
923     SparseBuffer& out = (*output_ragged)[d];
924     size_t prev_example_end_index =
925         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
926     out.example_end_indices.push_back(prev_example_end_index);
927   }
928 
929   return Status::OK();
930 }
931 
CheckConfigDataType(DataType dtype)932 Status CheckConfigDataType(DataType dtype) {
933   switch (dtype) {
934     case DT_INT64:
935     case DT_FLOAT:
936     case DT_STRING:
937       return Status::OK();
938     default:
939       return errors::InvalidArgument("Invalid config dtype: ",
940                                      DataTypeString(dtype));
941   }
942 }
943 
944 // Use this in the "default" clause of switch statements when dispatching
945 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)946 inline void ReportUnexpectedDataType(DataType dtype) {
947   DCHECK(false)
948       << "Encountered unexpected DataType " << DataTypeString(dtype)
949       << "in variable that should have been checked by CheckConfigDataType().";
950 }
951 
CheckConfigDataTypes(const Config & config)952 Status CheckConfigDataTypes(const Config& config) {
953   // Check config so we can safely CHECK(false) in switches on config.*.dtype
954   for (auto& c : config.sparse) {
955     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
956   }
957   for (auto& c : config.dense) {
958     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
959   }
960   for (auto& c : config.ragged) {
961     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
962     if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
963       return errors::InvalidArgument("Invalid ragged_split_type: ",
964                                      DataTypeString(c.splits_dtype));
965     }
966   }
967   return Status::OK();
968 }
969 
970 template <typename T>
971 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
972 
973 template <>
GetListFromBuffer(const SparseBuffer & buffer)974 const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
975   return buffer.int64_list;
976 }
977 template <>
GetListFromBuffer(const SparseBuffer & buffer)978 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
979   return buffer.float_list;
980 }
981 template <>
GetListFromBuffer(const SparseBuffer & buffer)982 const SmallVector<tstring>& GetListFromBuffer<tstring>(
983     const SparseBuffer& buffer) {
984   return buffer.bytes_list;
985 }
986 
987 template <typename T>
CopyOrMoveBlock(const T * b,const T * e,T * t)988 void CopyOrMoveBlock(const T* b, const T* e, T* t) {
989   std::copy(b, e, t);
990 }
991 template <>
CopyOrMoveBlock(const tstring * b,const tstring * e,tstring * t)992 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
993   std::move(b, e, t);
994 }
995 
996 template <typename T>
FillAndCopyVarLen(const int d,const size_t num_elements,const size_t num_elements_per_minibatch,const Config & config,const std::vector<std::vector<SparseBuffer>> & varlen_dense_buffers,Tensor * values)997 void FillAndCopyVarLen(
998     const int d, const size_t num_elements,
999     const size_t num_elements_per_minibatch, const Config& config,
1000     const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
1001     Tensor* values) {
1002   const Tensor& default_value = config.dense[d].default_value;
1003 
1004   // Copy-fill the tensors (creating the zero/fill-padding)
1005   std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
1006             default_value.flat<T>()(0));
1007 
1008   // Data is [batch_size, max_num_elements, data_stride_size]
1009   //   and num_elements_per_minibatch = max_num_elements * data_stride_size
1010   auto data = values->flat<T>().data();
1011 
1012   // Iterate over minibatch elements
1013   for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
1014     const SparseBuffer& buffer = varlen_dense_buffers[i][d];
1015     // Number of examples being stored in this buffer
1016     const auto& end_indices = buffer.example_end_indices;
1017     const size_t examples_in_buffer = end_indices.size();
1018     // const size_t stride_size = config.dense[d].elements_per_stride;
1019 
1020     const auto& list = GetListFromBuffer<T>(buffer);
1021     auto list_ptr = list.begin();
1022 
1023     size_t elements_tally = 0;
1024     // Iterate through all the examples stored in this buffer.
1025     for (size_t j = 0; j < examples_in_buffer; ++j) {
1026       // Number of elements stored for this example.
1027       const size_t num_elems = end_indices[j] - elements_tally;
1028       CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
1029       // Move forward this many elements in the varlen buffer.
1030       list_ptr += num_elems;
1031       // Move forward to the next minibatch entry in the values output.
1032       data += num_elements_per_minibatch;
1033       elements_tally = end_indices[j];
1034     }
1035     DCHECK(elements_tally == list.size());
1036   }
1037 }
1038 
1039 // Thin vector like interface wrapper around a Tensor. This enable us to
1040 // directly populate a tensor during parsing instead of having to first create a
1041 // vactor and then copy the data over.
1042 template <typename T>
1043 class TensorVector {
1044  public:
1045   using value_type = T;
1046 
tensor()1047   const Tensor& tensor() {
1048     if (!tensor_.has_value()) {
1049       resize(0);
1050     }
1051     return *tensor_;
1052   }
1053 
size() const1054   int64 size() const {
1055     return tensor_.has_value() ? tensor_->NumElements() : 0;
1056   }
resize(int64 new_size)1057   void resize(int64 new_size) {
1058     DCHECK(!tensor_.has_value());
1059     tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
1060     data_ = tensor_->flat<T>().data();
1061   }
data()1062   T* data() { return data_; }
data() const1063   const T* data() const { return data_; }
1064 
1065  private:
1066   // Use absl::optional to avoid calling the default constructor of Tensor
1067   // unnecessarily.
1068   absl::optional<Tensor> tensor_;
1069 
1070   // Cached pointer to the raw data inside the tensor.
1071   T* data_ = nullptr;
1072 };
1073 
CountSparseFeatures(const std::vector<std::vector<SparseBuffer>> & sparse_buffers,size_t d,size_t * total_num_features,size_t * max_num_features)1074 void CountSparseFeatures(
1075     const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
1076     size_t* total_num_features, size_t* max_num_features) {
1077   for (auto& sparse_values_tmp : sparse_buffers) {
1078     const std::vector<size_t>& end_indices =
1079         sparse_values_tmp[d].example_end_indices;
1080     *total_num_features += end_indices.back();
1081     *max_num_features = std::max(*max_num_features, end_indices[0]);
1082     for (size_t i = 1; i < end_indices.size(); ++i) {
1083       size_t example_size = end_indices[i] - end_indices[i - 1];
1084       *max_num_features = std::max(*max_num_features, example_size);
1085     }
1086   }
1087 }
1088 
CopySparseBufferToTensor(DataType dtype,size_t offset,SparseBuffer * src,Tensor * dst)1089 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
1090                               Tensor* dst) {
1091   switch (dtype) {
1092     case DT_INT64: {
1093       std::copy(src->int64_list.begin(), src->int64_list.end(),
1094                 dst->flat<int64>().data() + offset);
1095       break;
1096     }
1097     case DT_FLOAT: {
1098       std::copy(src->float_list.begin(), src->float_list.end(),
1099                 dst->flat<float>().data() + offset);
1100       break;
1101     }
1102     case DT_STRING: {
1103       std::move(src->bytes_list.begin(), src->bytes_list.end(),
1104                 dst->flat<tstring>().data() + offset);
1105       break;
1106     }
1107     default:
1108       ReportUnexpectedDataType(dtype);
1109   }
1110 }
1111 
1112 }  // namespace
1113 
FastParseExample(const Config & config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * result)1114 Status FastParseExample(const Config& config,
1115                         gtl::ArraySlice<tstring> serialized,
1116                         gtl::ArraySlice<tstring> example_names,
1117                         thread::ThreadPool* thread_pool, Result* result) {
1118   DCHECK(result != nullptr);
1119   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1120   TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1121 
1122   if (config.collect_feature_stats) {
1123     result->feature_stats.resize(serialized.size());
1124   }
1125 
1126   size_t config_size =
1127       config.dense.size() + config.sparse.size() + config.ragged.size();
1128   SeededHasher hasher;
1129   // Build config index.
1130   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1131   bool ok = true;
1132   for (size_t i = 0; i < 1000; ++i) {
1133     for (size_t d = 0; d < config.dense.size(); ++d) {
1134       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1135                                       {d, Type::Dense});
1136     }
1137     for (size_t d = 0; d < config.sparse.size(); ++d) {
1138       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1139                                       {d, Type::Sparse});
1140     }
1141     for (size_t d = 0; d < config.ragged.size(); ++d) {
1142       ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1143                                       {d, Type::Ragged});
1144     }
1145     if (ok) break;
1146     LOG(WARNING) << "Collision found. This should happen only if you have "
1147                     "around 2^32 entries in your config.";
1148     hasher.seed++;
1149     config_index.Clear(config_size);
1150     ok = true;
1151   }
1152   if (!ok) {
1153     return errors::Internal(
1154         "Could not avoid collision. This should not happen.");
1155   }
1156 
1157   // Allocate dense output for fixed length dense values
1158   // (variable-length dense and sparse and ragged have to be buffered).
1159   std::vector<Tensor> fixed_dense_values(config.dense.size());
1160   for (size_t d = 0; d < config.dense.size(); ++d) {
1161     if (config.dense[d].variable_length) continue;
1162     TensorShape out_shape;
1163     out_shape.AddDim(serialized.size());
1164     for (const int64 dim : config.dense[d].shape.dim_sizes()) {
1165       out_shape.AddDim(dim);
1166     }
1167     fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
1168   }
1169 
1170   // This parameter affects performance in a big and data-dependent way.
1171   const size_t kMiniBatchSizeBytes = 50000;
1172 
1173   // Calculate number of minibatches.
1174   // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
1175   // Apply 'special logic' below for small and big regimes.
1176   const size_t num_minibatches = [&] {
1177     size_t result = 0;
1178     size_t minibatch_bytes = 0;
1179     for (size_t i = 0; i < serialized.size(); i++) {
1180       if (minibatch_bytes == 0) {  // start minibatch
1181         result++;
1182       }
1183       minibatch_bytes += serialized[i].size() + 1;
1184       if (minibatch_bytes > kMiniBatchSizeBytes) {
1185         minibatch_bytes = 0;
1186       }
1187     }
1188     // 'special logic'
1189     const size_t min_minibatches = std::min<size_t>(8, serialized.size());
1190     const size_t max_minibatches = 64;
1191     return std::max<size_t>(min_minibatches,
1192                             std::min<size_t>(max_minibatches, result));
1193   }();
1194 
1195   auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
1196     return (serialized.size() * minibatch) / num_minibatches;
1197   };
1198 
1199   // TODO(lew): A big performance low-hanging fruit here is to improve
1200   //   num_minibatches calculation to take into account actual amount of work
1201   //   needed, as the size in bytes is not perfect. Linear combination of
1202   //   size in bytes and average number of features per example is promising.
1203   //   Even better: measure time instead of estimating, but this is too costly
1204   //   in small batches.
1205   //   Maybe accept outside parameter #num_minibatches?
1206 
1207   // Do minibatches in parallel.
1208   std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
1209   std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
1210   std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
1211   std::vector<Status> status_of_minibatch(num_minibatches);
1212   auto ProcessMiniBatch = [&](size_t minibatch) {
1213     sparse_buffers[minibatch].resize(config.sparse.size());
1214     varlen_dense_buffers[minibatch].resize(config.dense.size());
1215     ragged_buffers[minibatch].resize(config.ragged.size());
1216     size_t start = first_example_of_minibatch(minibatch);
1217     size_t end = first_example_of_minibatch(minibatch + 1);
1218     for (size_t e = start; e < end; ++e) {
1219       PerExampleFeatureStats* stats = nullptr;
1220       if (config.collect_feature_stats) {
1221         stats = &result->feature_stats[e];
1222       }
1223       status_of_minibatch[minibatch] = FastParseSerializedExample(
1224           serialized[e],
1225           (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
1226           config_index, hasher, &fixed_dense_values,
1227           &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
1228           &ragged_buffers[minibatch], stats);
1229       if (!status_of_minibatch[minibatch].ok()) break;
1230     }
1231   };
1232 
1233   ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
1234 
1235   for (Status& status : status_of_minibatch) {
1236     TF_RETURN_IF_ERROR(status);
1237   }
1238 
1239   result->sparse_indices.reserve(config.sparse.size());
1240   result->sparse_values.reserve(config.sparse.size());
1241   result->sparse_shapes.reserve(config.sparse.size());
1242   result->dense_values.reserve(config.dense.size());
1243   result->ragged_values.reserve(config.ragged.size());
1244   result->ragged_splits.reserve(config.ragged.size());
1245 
1246   for (size_t d = 0; d < config.dense.size(); ++d) {
1247     result->dense_values.push_back(std::move(fixed_dense_values[d]));
1248   }
1249 
1250   // Merge SparseBuffers from all minibatches for every config.sparse.
1251   auto MergeSparseMinibatches = [&](size_t d) {
1252     // Loop over minibatches
1253     size_t total_num_features = 0;
1254     size_t max_num_features = 0;
1255     CountSparseFeatures(sparse_buffers, d, &total_num_features,
1256                         &max_num_features);
1257 
1258     TensorShape indices_shape;
1259     indices_shape.AddDim(total_num_features);
1260     indices_shape.AddDim(2);
1261     result->sparse_indices.emplace_back(DT_INT64, indices_shape);
1262     Tensor* indices = &result->sparse_indices.back();
1263 
1264     TensorShape values_shape;
1265     values_shape.AddDim(total_num_features);
1266     result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
1267     Tensor* values = &result->sparse_values.back();
1268 
1269     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
1270     auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
1271     shapes_shape_t(0) = serialized.size();
1272     shapes_shape_t(1) = max_num_features;
1273 
1274     size_t offset = 0;
1275     for (size_t i = 0; i < sparse_buffers.size(); ++i) {
1276       SparseBuffer& buffer = sparse_buffers[i][d];
1277 
1278       // Update indices.
1279       size_t delta = 0;
1280 
1281       if (indices->NumElements() > 0) {
1282         int64* ix_p = &indices->matrix<int64>()(offset, 0);
1283         size_t example_index = first_example_of_minibatch(i);
1284         for (size_t example_end_index : buffer.example_end_indices) {
1285           size_t feature_index = 0;
1286           for (; delta < example_end_index; ++delta) {
1287             // Column 0: example index
1288             *ix_p = example_index;
1289             // Column 1: the feature index buffer example
1290             *(ix_p + 1) = feature_index;
1291             ix_p += 2;
1292             ++feature_index;
1293           }
1294           ++example_index;
1295         }
1296       }
1297 
1298       CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
1299       offset += delta;
1300     }
1301   };
1302 
1303   // Merge SparseBuffers from all minibatches for every config.ragged.
1304   auto MergeRaggedMinibatches = [&](size_t d) {
1305     // Loop over minibatches
1306     size_t total_num_features = 0;
1307     size_t max_num_features = 0;
1308     CountSparseFeatures(ragged_buffers, d, &total_num_features,
1309                         &max_num_features);
1310 
1311     TensorShape row_splits_shape;
1312     row_splits_shape.AddDim(serialized.size() + 1);
1313     result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
1314                                        row_splits_shape);
1315     Tensor* row_splits = &result->ragged_splits.back();
1316     if (config.ragged[d].splits_dtype == DT_INT64) {
1317       row_splits->flat<int64>()(0) = 0;
1318     } else {
1319       row_splits->flat<int32>()(0) = 0;
1320     }
1321 
1322     TensorShape values_shape;
1323     values_shape.AddDim(total_num_features);
1324     result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
1325     Tensor* values = &result->ragged_values.back();
1326 
1327     size_t values_offset = 0;
1328     size_t splits_offset = 0;
1329     for (size_t i = 0; i < ragged_buffers.size(); ++i) {
1330       SparseBuffer& buffer = ragged_buffers[i][d];
1331       if (buffer.example_end_indices.empty()) continue;
1332 
1333       // Update row_splits.  row_splits are formed by concatenating the example
1334       // end_indices (adjusting each to start after the previous one ends).
1335       if (config.ragged[d].splits_dtype == DT_INT64) {
1336         int64* row_splits_out = &row_splits->flat<int64>()(splits_offset);
1337         int64 start = *row_splits_out;
1338         for (size_t example_end_index : buffer.example_end_indices) {
1339           *++row_splits_out = start + example_end_index;
1340         }
1341       } else {
1342         int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
1343         int32 start = *row_splits_out;
1344         for (size_t example_end_index : buffer.example_end_indices) {
1345           *++row_splits_out = start + example_end_index;
1346         }
1347       }
1348 
1349       CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
1350                                values);
1351       values_offset += buffer.example_end_indices.back();
1352       splits_offset += buffer.example_end_indices.size();
1353     }
1354   };
1355 
1356   // Merge SparseBuffers from all minibatches for every config.dense having
1357   // variable_length.
1358   auto MergeDenseVarLenMinibatches = [&](size_t d) {
1359     if (!config.dense[d].variable_length) return;
1360 
1361     // Loop over minibatches
1362     size_t max_num_features = 0;
1363     for (auto& dense_values_tmp : varlen_dense_buffers) {
1364       std::vector<size_t>& end_indices =
1365           dense_values_tmp[d].example_end_indices;
1366       max_num_features = std::max(max_num_features, end_indices[0]);
1367       for (size_t i = 1; i < end_indices.size(); ++i) {
1368         size_t example_size = end_indices[i] - end_indices[i - 1];
1369         max_num_features = std::max(max_num_features, example_size);
1370       }
1371     }
1372 
1373     const size_t stride_size = config.dense[d].elements_per_stride;
1374     const size_t max_num_elements = max_num_features / stride_size;
1375     TensorShape values_shape;
1376     DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
1377     const size_t batch_size = serialized.size();
1378     values_shape.AddDim(batch_size);
1379     values_shape.AddDim(max_num_elements);
1380     for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1381       values_shape.AddDim(config.dense[d].shape.dim_size(i));
1382     }
1383     Tensor values(config.dense[d].dtype, values_shape);
1384     result->dense_values[d] = values;
1385     const size_t num_elements = values.NumElements();
1386 
1387     // Nothing to write, exit early.
1388     if (num_elements == 0) return;
1389 
1390     const size_t num_elements_per_minibatch = num_elements / batch_size;
1391 
1392     switch (config.dense[d].dtype) {
1393       case DT_INT64: {
1394         FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
1395                                  config, varlen_dense_buffers, &values);
1396         break;
1397       }
1398       case DT_FLOAT: {
1399         FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
1400                                  config, varlen_dense_buffers, &values);
1401         break;
1402       }
1403       case DT_STRING: {
1404         FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
1405                                    config, varlen_dense_buffers, &values);
1406         break;
1407       }
1408       default:
1409         ReportUnexpectedDataType(config.dense[d].dtype);
1410     }
1411   };
1412 
1413   for (size_t d = 0; d < config.dense.size(); ++d) {
1414     MergeDenseVarLenMinibatches(d);
1415   }
1416 
1417   for (size_t d = 0; d < config.sparse.size(); ++d) {
1418     MergeSparseMinibatches(d);
1419   }
1420 
1421   for (size_t d = 0; d < config.ragged.size(); ++d) {
1422     MergeRaggedMinibatches(d);
1423   }
1424 
1425   return Status::OK();
1426 }
1427 
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1428 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1429                               Result* result) {
1430   DCHECK(result != nullptr);
1431   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1432   TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1433 
1434   PerExampleFeatureStats* stats = nullptr;
1435   if (config.collect_feature_stats) {
1436     result->feature_stats.emplace_back();
1437     stats = &result->feature_stats.back();
1438   }
1439 
1440   // TODO(mrry): Cache the construction of this map at Op construction time.
1441   size_t config_size =
1442       config.dense.size() + config.sparse.size() + config.ragged.size();
1443   SeededHasher hasher;
1444   // Build config index.
1445   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1446   bool ok = true;
1447   for (size_t i = 0; i < 1000; ++i) {
1448     for (size_t d = 0; d < config.dense.size(); ++d) {
1449       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1450                                       {d, Type::Dense});
1451     }
1452     for (size_t d = 0; d < config.sparse.size(); ++d) {
1453       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1454                                       {d, Type::Sparse});
1455     }
1456     for (size_t d = 0; d < config.ragged.size(); ++d) {
1457       ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1458                                       {d, Type::Ragged});
1459     }
1460     if (ok) break;
1461     LOG(WARNING) << "Collision found. This should happen only if you have "
1462                     "around 2^32 entries in your config.";
1463     hasher.seed++;
1464     config_index.Clear(config_size);
1465     ok = true;
1466   }
1467   if (!ok) {
1468     return errors::Internal(
1469         "Could not avoid collision. This should not happen.");
1470   }
1471 
1472   result->sparse_indices.reserve(config.sparse.size());
1473   result->sparse_values.reserve(config.sparse.size());
1474   result->sparse_shapes.reserve(config.sparse.size());
1475   result->dense_values.reserve(config.dense.size());
1476   result->ragged_values.reserve(config.ragged.size());
1477   result->ragged_splits.reserve(config.ragged.size());
1478 
1479   // Allocate dense output tensors.
1480   for (size_t d = 0; d < config.dense.size(); ++d) {
1481     if (!config.dense[d].variable_length) {
1482       TensorShape values_shape;
1483       if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1484         return errors::Internal(
1485             "Fixed-length shape was not a statically defined shape.");
1486       }
1487       result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1488     } else {
1489       // Variable-length tensor will be allocated later.
1490       result->dense_values.emplace_back();
1491     }
1492   }
1493 
1494   // Allocate sparse output tensors.
1495   for (size_t d = 0; d < config.sparse.size(); ++d) {
1496     // The dense_shape is always a vector of length 1.
1497     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1498     // Variable-length tensors will be allocated later.
1499     result->sparse_indices.emplace_back();
1500     result->sparse_values.emplace_back();
1501   }
1502 
1503   // Allocate ragged output tensors.
1504   for (size_t d = 0; d < config.ragged.size(); ++d) {
1505     // Variable-length values tensors will be allocated later.
1506     result->ragged_values.emplace_back();
1507     // Splits tensors are empty (unused) for single (scalar) inputs.
1508     const auto splits_dtype = config.ragged[d].splits_dtype;
1509     result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1510   }
1511 
1512   parsed::Example parsed_example;
1513   if (!ParseExample(serialized, &parsed_example)) {
1514     return errors::InvalidArgument("Could not parse example input, value: '",
1515                                    serialized, "'");
1516   }
1517   std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1518   std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1519   std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1520 
1521   if (stats) {
1522     // TODO(b/111553342): This may over-count the number of features if there
1523     // are duplicate keys in the feature map. Consider deduplicating the keys
1524     // before computing the count.
1525     stats->features_count = parsed_example.size();
1526   }
1527 
1528   // Handle features present in the example.
1529   const size_t parsed_example_size = parsed_example.size();
1530   for (size_t i = 0; i < parsed_example_size; ++i) {
1531     // This is a logic that standard protobuf parsing is implementing.
1532     // I.e. last entry in the map overwrites all the previous ones.
1533     parsed::FeatureMapEntry& name_and_feature =
1534         parsed_example[parsed_example_size - i - 1];
1535 
1536     const StringPiece feature_name = name_and_feature.first;
1537     parsed::Feature& feature = name_and_feature.second;
1538 
1539     std::pair<size_t, Type> d_and_type;
1540     uint64 h = hasher(feature_name);
1541     if (!config_index.Find(h, &d_and_type)) continue;
1542 
1543     size_t d = d_and_type.first;
1544     bool is_dense = d_and_type.second == Type::Dense;
1545     bool is_sparse = d_and_type.second == Type::Sparse;
1546 
1547     {
1548       // Testing for PresizedCuckooMap collision.
1549       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1550       const tstring& config_feature_name =
1551           is_dense ? config.dense[d].feature_name
1552                    : (is_sparse ? config.sparse[d].feature_name
1553                                 : config.ragged[d].feature_name);
1554       if (feature_name != config_feature_name) continue;
1555     }
1556 
1557     auto example_error = [feature_name](StringPiece suffix) {
1558       return errors::InvalidArgument("Key: ", feature_name, ".  ", suffix);
1559     };
1560 
1561     auto parse_error = [feature_name] {
1562       return errors::InvalidArgument("Key: ", feature_name,
1563                                      ".  Can't parse serialized Example.");
1564     };
1565 
1566     DataType example_dtype;
1567     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1568     if (example_dtype == DT_INVALID) continue;
1569 
1570     if (is_dense && !config.dense[d].variable_length) {
1571       // If feature was already visited, skip.
1572       // Compare comment at the beginning of the loop.
1573       if (dense_feature_already_seen[d]) {
1574         LogDenseFeatureDataLoss(feature_name);
1575         continue;
1576       }
1577       dense_feature_already_seen[d] = true;
1578 
1579       if (example_dtype != config.dense[d].dtype) {
1580         return example_error(strings::StrCat(
1581             "Data types don't match. Data type: ",
1582             DataTypeString(example_dtype),
1583             " but expected type: ", DataTypeString(config.dense[d].dtype)));
1584       }
1585 
1586       Tensor* out = &result->dense_values[d];
1587       const std::size_t num_elements = config.dense[d].elements_per_stride;
1588       if (stats) {
1589         // TODO(b/111553342): If desirable, we could add support for counting
1590         // elements in the features that aren't parsed, but this could add
1591         // considerable runtime cost.
1592         stats->feature_values_count += num_elements;
1593       }
1594       switch (example_dtype) {
1595         case DT_INT64: {
1596           auto out_p = out->flat<int64>().data();
1597           LimitedArraySlice<int64> slice(out_p, num_elements);
1598           if (!feature.ParseInt64List(&slice)) return parse_error();
1599           if (slice.EndDistance() != 0) {
1600             return parse_error();
1601           }
1602           break;
1603         }
1604         case DT_FLOAT: {
1605           auto out_p = out->flat<float>().data();
1606           LimitedArraySlice<float> slice(out_p, num_elements);
1607           if (!feature.ParseFloatList(&slice)) return parse_error();
1608           if (slice.EndDistance() != 0) {
1609             return parse_error();
1610           }
1611           break;
1612         }
1613         case DT_STRING: {
1614           auto out_p = out->flat<tstring>().data();
1615           LimitedArraySlice<tstring> slice(out_p, num_elements);
1616           if (!feature.ParseBytesList(&slice)) return parse_error();
1617           if (slice.EndDistance() != 0) {
1618             return parse_error();
1619           }
1620           break;
1621         }
1622         default:
1623           ReportUnexpectedDataType(example_dtype);
1624       }
1625 
1626     } else {  // if variable length
1627       SmallVector<tstring> bytes_list;
1628       TensorVector<float> float_list;
1629       SmallVector<int64> int64_list;
1630 
1631       const size_t num_elements_divisor =
1632           is_dense ? config.dense[d].elements_per_stride : 1;
1633       size_t num_elements;
1634 
1635       if (is_dense) {
1636         // If feature was already visited, skip.
1637         // Compare comment at the beginning of the loop.
1638         if (dense_feature_already_seen[d]) {
1639           LogDenseFeatureDataLoss(feature_name);
1640           continue;
1641         }
1642         dense_feature_already_seen[d] = true;
1643         if (example_dtype != config.dense[d].dtype) {
1644           return example_error(strings::StrCat(
1645               "Data types don't match. Data type: ",
1646               DataTypeString(example_dtype),
1647               " but expected type: ", DataTypeString(config.dense[d].dtype)));
1648         }
1649       } else {
1650         // Feature is sparse or ragged.
1651         auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
1652                                                : ragged_feature_already_seen;
1653         auto& feature_dtype =
1654             is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
1655         // If feature was already visited, skip.
1656         // Compare comment at the beginning of the loop.
1657         if (feature_already_seen[d]) {
1658           LogSparseFeatureDataLoss(feature_name);
1659           continue;
1660         }
1661         feature_already_seen[d] = true;
1662 
1663         // Handle sparse features.
1664         if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
1665           return example_error(strings::StrCat(
1666               "Data types don't match. ",
1667               "Expected type: ", DataTypeString(feature_dtype),
1668               ", Actual type: ", DataTypeString(example_dtype)));
1669         }
1670       }
1671 
1672       switch (example_dtype) {
1673         case DT_INT64: {
1674           // TODO(mrry): Use the fact that the `int64_list` is packed to read
1675           // out the length and pre-allocate the output tensor.
1676           if (!feature.ParseInt64List(&int64_list)) return parse_error();
1677           num_elements = int64_list.size();
1678           break;
1679         }
1680         case DT_FLOAT: {
1681           if (!feature.ParseFloatList(&float_list)) return parse_error();
1682           num_elements = float_list.size();
1683           break;
1684         }
1685         case DT_STRING: {
1686           int actual_num_elements = 0;
1687           if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1688             return parse_error();
1689           }
1690           bytes_list.reserve(actual_num_elements);
1691           if (!feature.ParseBytesList(&bytes_list)) return parse_error();
1692           num_elements = bytes_list.size();
1693           break;
1694         }
1695         default:
1696           num_elements = 0;
1697           ReportUnexpectedDataType(example_dtype);
1698       }
1699 
1700       if (num_elements % num_elements_divisor != 0) {
1701         return parse_error();
1702       }
1703 
1704       if (stats) {
1705         stats->feature_values_count += num_elements;
1706       }
1707 
1708       Tensor* out;
1709       DataType out_dtype;
1710       TensorShape out_shape;
1711       if (is_dense) {
1712         out_shape.AddDim(num_elements / num_elements_divisor);
1713         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1714           out_shape.AddDim(config.dense[d].shape.dim_size(i));
1715         }
1716 
1717         out = &result->dense_values[d];
1718         out_dtype = config.dense[d].dtype;
1719       } else if (is_sparse) {
1720         Tensor* out_indices = &result->sparse_indices[d];
1721         Tensor* out_dense_shape = &result->sparse_shapes[d];
1722 
1723         // TODO(mrry): Investigate the possibility of not materializing
1724         // the indices (and perhaps dense_shape) until they are needed.
1725         *out_indices = Tensor(
1726             DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
1727         auto indices_flat = out_indices->flat<int64>();
1728         for (size_t i = 0; i < num_elements; ++i) {
1729           indices_flat(i) = static_cast<int64>(i);
1730         }
1731 
1732         *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1733         auto shapes_shape_t = out_dense_shape->vec<int64>();
1734         shapes_shape_t(0) = num_elements;
1735 
1736         out = &result->sparse_values[d];
1737         out_dtype = config.sparse[d].dtype;
1738         out_shape.AddDim(num_elements);
1739       } else {
1740         out = &result->ragged_values[d];
1741         out_dtype = config.ragged[d].dtype;
1742         out_shape.AddDim(num_elements);
1743       }
1744 
1745       switch (example_dtype) {
1746         case DT_INT64: {
1747           *out = Tensor(out_dtype, out_shape);
1748           CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1749                           out->flat<int64>().data());
1750           break;
1751         }
1752         case DT_FLOAT: {
1753           if (!out->CopyFrom(float_list.tensor(), out_shape)) {
1754             return parse_error();
1755           }
1756           break;
1757         }
1758         case DT_STRING: {
1759           *out = Tensor(out_dtype, out_shape);
1760           CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
1761                           out->flat<tstring>().data());
1762           break;
1763         }
1764         default:
1765           ReportUnexpectedDataType(example_dtype);
1766       }
1767     }
1768   }
1769 
1770   // Handle missing dense features.
1771   for (size_t d = 0; d < config.dense.size(); ++d) {
1772     if (!dense_feature_already_seen[d]) {
1773       if (!config.dense[d].variable_length) {
1774         // Handle missing fixed-length dense feature.
1775         if (config.dense[d].default_value.NumElements() == 0) {
1776           return errors::InvalidArgument(
1777               "Feature: ", config.dense[d].feature_name,
1778               " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1779               " is required but could not be found.");
1780         }
1781         result->dense_values[d] = config.dense[d].default_value;
1782       } else {
1783         // Handle missing varlen dense feature.
1784         TensorShape empty_shape;
1785         empty_shape.AddDim(0);
1786         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1787           empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1788         }
1789         result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1790       }
1791     }
1792   }
1793 
1794   // Handle missing sparse features.
1795   for (size_t d = 0; d < config.sparse.size(); ++d) {
1796     if (!sparse_feature_already_seen[d]) {
1797       result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1798       result->sparse_values[d] =
1799           Tensor(config.sparse[d].dtype, TensorShape({0}));
1800       result->sparse_shapes[d].vec<int64>()(0) = 0;
1801     }
1802   }
1803 
1804   // Handle missing ragged features.
1805   for (size_t d = 0; d < config.ragged.size(); ++d) {
1806     if (!ragged_feature_already_seen[d]) {
1807       result->ragged_values[d] =
1808           Tensor(config.ragged[d].dtype, TensorShape({0}));
1809     }
1810   }
1811 
1812   return Status::OK();
1813 }
1814 
1815 // Private helper functions for FastParseSequenceExample.
1816 namespace {
1817 
1818 // A struct used by FastParseSequenceExample to hold the serialized proto
1819 // substrings for a single feature, plus some auxiliary information derived
1820 // from those protos (such as the total value length).
1821 struct FeatureProtos {
1822   // Proto substrings from each serialized SequenceExample that correspond
1823   // with this feature.  `protos_present` records whether the proto had a
1824   // value defined (even if that value is empty).
1825   std::vector<StringPiece> protos;
1826   std::vector<bool> protos_present;
1827 
1828   // Information derived from protos:
1829   size_t length;    // total length for ragged/sparse, max row length for dense.
1830   size_t num_rows;  // only populated for ragged sequence features.
1831 
1832   // Information from the config:
1833   Type type;  // Whether this feature is sparse, ragged, or dense.
1834   DataType dtype;
1835 };
1836 
1837 // Map from feature name to FeatureProtos for that feature.
1838 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
1839 
ExampleName(const gtl::ArraySlice<tstring> example_names,int n)1840 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
1841   return example_names.empty() ? "<unknown>" : example_names[n];
1842 }
1843 
1844 // Return the number of bytes elements parsed, or -1 on error. If out is null,
1845 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)1846 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
1847                              tstring* out) {
1848   int num_elements = 0;
1849   uint32 length;
1850   if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
1851     return -1;
1852   }
1853   if (length > 0) {
1854     auto limit = stream->PushLimit(length);
1855     while (!stream->ExpectAtEnd()) {
1856       uint32 bytes_length;
1857       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1858           !stream->ReadVarint32(&bytes_length)) {
1859         return -1;
1860       }
1861       if (out == nullptr) {
1862         stream->Skip(bytes_length);
1863       } else {
1864         out->resize_uninitialized(bytes_length);
1865         if (!stream->ReadRaw(out->data(), bytes_length)) {
1866           return -1;
1867         }
1868         out++;
1869       }
1870       num_elements++;
1871     }
1872     stream->PopLimit(limit);
1873   }
1874   return num_elements;
1875 }
1876 
PadFloatFeature(int num_to_pad,float * out)1877 inline void PadFloatFeature(int num_to_pad, float* out) {
1878   for (int i = 0; i < num_to_pad; i++) {
1879     *out++ = 0.0;
1880   }
1881 }
1882 
PadInt64Feature(int num_to_pad,int64 * out)1883 inline void PadInt64Feature(int num_to_pad, int64* out) {
1884   for (int i = 0; i < num_to_pad; i++) {
1885     *out++ = 0;
1886   }
1887 }
1888 
1889 // Return the number of float elements parsed, or -1 on error. If out is null,
1890 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)1891 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
1892                              float* out) {
1893   int num_elements = 0;
1894   uint32 length;
1895   if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
1896     return -1;
1897   }
1898   if (length > 0) {
1899     auto limit = stream->PushLimit(length);
1900     uint8 peek_tag = PeekTag(stream);
1901     if (peek_tag == kDelimitedTag(1)) {  // packed
1902       uint32 packed_length;
1903       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1904           !stream->ReadVarint32(&packed_length)) {
1905         return -1;
1906       }
1907       auto packed_limit = stream->PushLimit(packed_length);
1908       while (!stream->ExpectAtEnd()) {
1909         uint32 buffer32;
1910         if (!stream->ReadLittleEndian32(&buffer32)) {
1911           return -1;
1912         }
1913         if (out != nullptr) {
1914           *out++ = absl::bit_cast<float>(buffer32);
1915         }
1916         num_elements++;
1917       }
1918       stream->PopLimit(packed_limit);
1919     } else if (peek_tag == kFixed32Tag(1)) {
1920       while (!stream->ExpectAtEnd()) {
1921         uint32 buffer32;
1922         if (!stream->ExpectTag(kFixed32Tag(1)) ||
1923             !stream->ReadLittleEndian32(&buffer32)) {
1924           return -1;
1925         }
1926         if (out != nullptr) {
1927           *out++ = absl::bit_cast<float>(buffer32);
1928         }
1929         num_elements++;
1930       }
1931     } else {
1932       // Unknown tag.
1933       return -1;
1934     }
1935     stream->PopLimit(limit);
1936   }
1937   return num_elements;
1938 }
1939 
1940 // Return the number of int64 elements parsed, or -1 on error. If out is null,
1941 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64 * out)1942 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
1943                              int64* out) {
1944   int num_elements = 0;
1945   uint32 length;
1946   if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
1947     return -1;
1948   }
1949   if (length > 0) {
1950     auto limit = stream->PushLimit(length);
1951     uint8 peek_tag = PeekTag(stream);
1952     if (peek_tag == kDelimitedTag(1)) {  // packed
1953       uint32 packed_length;
1954       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1955           !stream->ReadVarint32(&packed_length)) {
1956         return -1;
1957       }
1958       auto packed_limit = stream->PushLimit(packed_length);
1959       while (!stream->ExpectAtEnd()) {
1960         protobuf_uint64 n;  // There is no API for int64
1961         if (!stream->ReadVarint64(&n)) {
1962           return -1;
1963         }
1964         if (out != nullptr) {
1965           *out++ = n;
1966         }
1967         num_elements++;
1968       }
1969       stream->PopLimit(packed_limit);
1970     } else if (peek_tag == kVarintTag(1)) {
1971       while (!stream->ExpectAtEnd()) {
1972         protobuf_uint64 n;  // There is no API for int64
1973         if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
1974           return -1;
1975         }
1976         if (out != nullptr) {
1977           *out++ = n;
1978         }
1979         num_elements++;
1980       }
1981     } else {
1982       // Unknown tag.
1983       return -1;
1984     }
1985     stream->PopLimit(limit);
1986   }
1987   return num_elements;
1988 }
1989 
1990 // Parses the next feature on `stream` into `out` starting at `out_offset`.
1991 // Updates `out_offset`, and returns the number of values added.
1992 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)1993 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
1994                         Tensor* out, size_t* out_offset) {
1995   int delta;
1996   switch (dtype) {
1997     case DT_STRING:
1998       delta =
1999           ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
2000       break;
2001     case DT_FLOAT:
2002       delta =
2003           ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
2004       break;
2005     case DT_INT64:
2006       delta =
2007           ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
2008       break;
2009     default:
2010       ReportUnexpectedDataType(dtype);
2011       delta = 0;
2012   }
2013   if (delta > 0) {
2014     *out_offset += delta;
2015   }
2016   return delta;
2017 }
2018 
2019 // Returns the length of the next feature on `stream`.
2020 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)2021 inline int GetFeatureLength(DataType dtype,
2022                             protobuf::io::CodedInputStream* stream) {
2023   switch (dtype) {
2024     case DT_STRING:
2025       return ParseBytesFeature(stream, nullptr);
2026     case DT_FLOAT:
2027       return ParseFloatFeature(stream, nullptr);
2028     case DT_INT64:
2029       return ParseInt64Feature(stream, nullptr);
2030     default:
2031       ReportUnexpectedDataType(dtype);
2032       return -1;
2033   }
2034 }
2035 
ParseDataType(protobuf::io::CodedInputStream * stream)2036 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
2037   uint8 peek_tag = PeekTag(stream);
2038   switch (peek_tag) {
2039     case kDelimitedTag(1):
2040       return DT_STRING;
2041     case kDelimitedTag(2):
2042       return DT_FLOAT;
2043     case kDelimitedTag(3):
2044       return DT_INT64;
2045     default:
2046       return DT_INVALID;
2047   }
2048 }
2049 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)2050 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
2051                              DataType dtype) {
2052   switch (dtype) {
2053     case DT_STRING:
2054       if (!stream->ExpectTag(kDelimitedTag(1))) {
2055         return false;
2056       }
2057       break;
2058     case DT_FLOAT:
2059       if (!stream->ExpectTag(kDelimitedTag(2))) {
2060         return false;
2061       }
2062       break;
2063     case DT_INT64:
2064       if (!stream->ExpectTag(kDelimitedTag(3))) {
2065         return false;
2066       }
2067       break;
2068     default:
2069       return false;
2070   }
2071   uint32 length;
2072   return stream->ReadVarint32(&length) && length == 0;
2073 }
2074 
2075 // Reads an example proto, and extracts a StringPiece pointer to each feature.
ExtractFeaturesFromSequenceExamples(const gtl::ArraySlice<tstring> examples,const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features,FeatureProtosMap * sequence_features)2076 Status ExtractFeaturesFromSequenceExamples(
2077     const gtl::ArraySlice<tstring> examples,
2078     const gtl::ArraySlice<tstring> example_names,
2079     FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
2080   for (int d = 0; d < examples.size(); d++) {
2081     const tstring& example = examples[d];
2082     protobuf::io::CodedInputStream stream(
2083         reinterpret_cast<const uint8*>(example.data()), example.size());
2084     // Not clear what this does. Why not stream.EnableAliasing()?
2085     EnableAliasing(&stream);
2086 
2087     // Extract pointers to all features within this serialized example.
2088     while (!stream.ExpectAtEnd()) {
2089       FeatureProtosMap* features = nullptr;
2090       if (stream.ExpectTag(kDelimitedTag(1))) {
2091         // Context
2092         features = context_features;
2093       } else if (stream.ExpectTag(kDelimitedTag(2))) {
2094         // Sequence
2095         features = sequence_features;
2096       } else if (!SkipExtraneousTag(&stream)) {
2097         return errors::InvalidArgument(
2098             "Invalid protocol message input, example id: ",
2099             ExampleName(example_names, d));
2100       }
2101       if (features != nullptr) {
2102         uint32 length;
2103         if (!stream.ReadVarint32(&length)) {
2104           return errors::InvalidArgument(
2105               "Invalid protocol message input, example id: ",
2106               ExampleName(example_names, d));
2107         }
2108         auto limit = stream.PushLimit(length);
2109         while (!stream.ExpectAtEnd()) {
2110           StringPiece key, value;
2111           uint32 length;
2112           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2113               !stream.ReadVarint32(&length)) {
2114             return errors::InvalidArgument(
2115                 "Invalid protocol message input, example id: ",
2116                 ExampleName(example_names, d));
2117           }
2118           auto limit = stream.PushLimit(length);
2119           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2120               !ParseString(&stream, &key) ||
2121               !stream.ExpectTag(kDelimitedTag(2)) ||
2122               !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
2123             return errors::InvalidArgument(
2124                 "Invalid protocol message input, example id: ",
2125                 ExampleName(example_names, d));
2126           }
2127           stream.PopLimit(limit);
2128           // Only save if this feature was requested.
2129           auto feature_iter = features->find(key);
2130           if (feature_iter != features->end()) {
2131             auto& feature = feature_iter->second;
2132             feature.protos[d] = value;
2133             feature.protos_present[d] = true;
2134           }
2135         }
2136         stream.PopLimit(limit);
2137       }
2138     }
2139   }
2140   return Status::OK();
2141 }
2142 
2143 // Populates context_features[k].length based on context_features[k].protos
2144 // (for all k).
GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features)2145 Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2146                                 FeatureProtosMap* context_features) {
2147   for (auto& c : *context_features) {
2148     FeatureProtos& feature = c.second;
2149     for (int d = 0; d < feature.protos.size(); ++d) {
2150       const auto& proto = feature.protos[d];
2151       if (proto.empty()) continue;
2152       protobuf::io::CodedInputStream stream(
2153           reinterpret_cast<const uint8*>(proto.data()), proto.size());
2154       EnableAliasing(&stream);
2155       int num_elements = GetFeatureLength(feature.dtype, &stream);
2156       if (num_elements < 0) {
2157         return errors::InvalidArgument(
2158             "Name: ", ExampleName(example_names, d),
2159             ", Context feature: ", c.first,
2160             ".  Data types don't match. Expected type: ",
2161             DataTypeString(feature.dtype));
2162       }
2163       switch (feature.type) {
2164         case Type::Sparse:  // intentional fall-through
2165         case Type::Ragged:
2166           feature.length += num_elements;
2167           break;
2168         case Type::Dense:
2169           feature.length =
2170               std::max(feature.length, static_cast<size_t>(num_elements));
2171           break;
2172       }
2173     }
2174   }
2175   return Status::OK();
2176 }
2177 
2178 // Populates sequence_features[k].length and sequence_features[k].num_rows based
2179 // on sequence_features[k].protos (for all k).
GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * sequence_features)2180 Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2181                                  FeatureProtosMap* sequence_features) {
2182   for (auto& c : *sequence_features) {
2183     FeatureProtos& feature = c.second;
2184     for (int d = 0; d < feature.protos.size(); ++d) {
2185       const auto& proto = feature.protos[d];
2186       if (proto.empty()) continue;
2187 
2188       size_t num_rows = 0;
2189       size_t num_elements = 0;
2190       protobuf::io::CodedInputStream stream(
2191           reinterpret_cast<const uint8*>(proto.data()), proto.size());
2192       EnableAliasing(&stream);
2193       while (!stream.ExpectAtEnd()) {
2194         uint32 feature_bytes;
2195         if (!stream.ExpectTag(kDelimitedTag(1)) ||
2196             !stream.ReadVarint32(&feature_bytes)) {
2197           return errors::InvalidArgument("Error in sequence feature ", c.first,
2198                                          " in example ",
2199                                          ExampleName(example_names, d));
2200         }
2201         if (feature_bytes > 2) {
2202           auto limit = stream.PushLimit(feature_bytes);
2203           int delta = GetFeatureLength(feature.dtype, &stream);
2204           if (delta < 0) {
2205             return errors::InvalidArgument(
2206                 "Name: ", ExampleName(example_names, d),
2207                 ", Feature list: ", c.first, ", Index: ", num_rows,
2208                 ".  Data types don't match. Expected type: ",
2209                 DataTypeString(feature.dtype));
2210           }
2211           num_elements += delta;
2212           stream.PopLimit(limit);
2213         } else if (feature_bytes == 2) {
2214           if (!SkipEmptyFeature(&stream, feature.dtype)) {
2215             return errors::InvalidArgument(
2216                 "Name: ", ExampleName(example_names, d),
2217                 ", Feature list: ", c.first, ", Index: ", num_rows,
2218                 ".  Data types don't match. Expected type: ",
2219                 DataTypeString(feature.dtype));
2220           }
2221         } else if (feature_bytes != 0) {
2222           return errors::InvalidArgument("Error in sequence feature ", c.first,
2223                                          " in example ",
2224                                          ExampleName(example_names, d));
2225         }
2226         ++num_rows;
2227       }
2228       switch (feature.type) {
2229         case Type::Sparse:
2230           feature.length += num_elements;
2231           break;
2232         case Type::Ragged:
2233           feature.length += num_elements;
2234           feature.num_rows += num_rows;
2235           break;
2236         case Type::Dense:
2237           feature.length = std::max(feature.length, num_elements);
2238           break;
2239       }
2240     }
2241   }
2242   return Status::OK();
2243 }
2244 
2245 // Copies src into dst[dst_offset:dst_offset+src.size], and then increments
2246 // dst_offset by src.size.
CopyTensorIntoTensor(DataType dtype,const Tensor & src,Tensor * dst,size_t * dst_offset)2247 void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
2248                           size_t* dst_offset) {
2249   size_t src_size = src.NumElements();
2250   switch (dtype) {
2251     case DT_INT64: {
2252       auto src_t = src.flat<int64>().data();
2253       std::copy(src_t, src_t + src_size,
2254                 dst->flat<int64>().data() + *dst_offset);
2255       break;
2256     }
2257     case DT_FLOAT: {
2258       auto src_t = src.flat<float>().data();
2259       std::copy(src_t, src_t + src_size,
2260                 dst->flat<float>().data() + *dst_offset);
2261       break;
2262     }
2263     case DT_STRING: {
2264       auto src_t = src.flat<tstring>().data();
2265       std::copy(src_t, src_t + src_size,
2266                 dst->flat<tstring>().data() + *dst_offset);
2267       break;
2268     }
2269     default:
2270       ReportUnexpectedDataType(dtype);
2271   }
2272   *dst_offset += src_size;
2273 }
2274 
2275 // Parses dense features in `context_features`, and writes their parsed
2276 // values to `context_results`.
ParseContextDenseFeatures(const FeatureProtosMap & context_features,const FastParseExampleConfig & context_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * context_result)2277 Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
2278                                  const FastParseExampleConfig& context_config,
2279                                  gtl::ArraySlice<tstring> example_names,
2280                                  bool is_batch, int num_examples,
2281                                  Allocator* allocator, Result* context_result) {
2282   for (int t = 0; t < context_config.dense.size(); ++t) {
2283     const auto& c = context_config.dense[t];
2284     const FeatureProtos& feature =
2285         context_features.find(c.feature_name)->second;
2286     TensorShape dense_shape, example_shape;
2287     DataType dtype = c.dtype;
2288     const size_t data_max_elements = feature.length;
2289     if (!c.shape.AsTensorShape(&example_shape) ||
2290         data_max_elements != example_shape.num_elements()) {
2291       return errors::InvalidArgument(
2292           "Inconsistent max number of elements for feature ", c.feature_name,
2293           ": expected ", example_shape.num_elements(), ", but found ",
2294           data_max_elements);
2295     }
2296     if (is_batch) {
2297       dense_shape.AddDim(num_examples);
2298     }
2299     for (const int dim : c.shape.dim_sizes()) {
2300       dense_shape.AddDim(dim);
2301     }
2302     context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2303 
2304     Tensor& out = context_result->dense_values[t];
2305     size_t out_offset = 0;
2306 
2307     // Fill in the values.
2308     for (int e = 0; e < num_examples; e++) {
2309       size_t num_elements = 0;
2310       const auto& feature_proto = feature.protos[e];
2311       if (!feature.protos_present[e]) {
2312         // Copy the default value, if present. If not, return an error.
2313         if (c.default_value.NumElements() == 0) {
2314           return errors::InvalidArgument(
2315               "Feature: ", c.feature_name,
2316               " (data type: ", DataTypeString(c.dtype), ")",
2317               " is required but could not be found.");
2318         }
2319         CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
2320         num_elements += c.default_value.NumElements();
2321       } else if (!feature_proto.empty()) {
2322         protobuf::io::CodedInputStream stream(
2323             reinterpret_cast<const uint8*>(feature_proto.data()),
2324             feature_proto.size());
2325         EnableAliasing(&stream);
2326         num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
2327       }
2328       if (num_elements != data_max_elements) {
2329         return errors::InvalidArgument(
2330             "Unexpected number of elements in example ",
2331             ExampleName(example_names, e));
2332       }
2333     }
2334   }
2335   return Status::OK();
2336 }
2337 
2338 // Parses sparse features in `context_features`, and writes their parsed
2339 // values to `context_results`.
ParseContextSparseFeatures(const FeatureProtosMap & context_features,const FastParseExampleConfig & context_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * context_result)2340 Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
2341                                   const FastParseExampleConfig& context_config,
2342                                   gtl::ArraySlice<tstring> example_names,
2343                                   bool is_batch, int num_examples,
2344                                   Allocator* allocator,
2345                                   Result* context_result) {
2346   for (int t = 0; t < context_config.sparse.size(); ++t) {
2347     const auto& c = context_config.sparse[t];
2348     const FeatureProtos& feature =
2349         context_features.find(c.feature_name)->second;
2350     TensorShape indices_shape, values_shape;
2351     DataType dtype = c.dtype;
2352     size_t expected_num_elements = feature.length;
2353     indices_shape.AddDim(expected_num_elements);
2354     indices_shape.AddDim(is_batch ? 2 : 1);
2355     values_shape.AddDim(expected_num_elements);
2356     context_result->sparse_indices[t] =
2357         Tensor(allocator, DT_INT64, indices_shape);
2358     context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
2359     context_result->sparse_shapes[t] =
2360         Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
2361     Tensor& out_values = context_result->sparse_values[t];
2362     size_t out_values_offset = 0;
2363     int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
2364     auto out_shape = context_result->sparse_shapes[t].vec<int64>();
2365 
2366     // Fill in the values.
2367     size_t num_elements = 0;
2368     size_t max_num_cols = 0;
2369     for (int e = 0; e < num_examples; e++) {
2370       const auto& feature_proto = feature.protos[e];
2371       if (feature_proto.empty()) continue;
2372       protobuf::io::CodedInputStream stream(
2373           reinterpret_cast<const uint8*>(feature_proto.data()),
2374           feature_proto.size());
2375       EnableAliasing(&stream);
2376       size_t num_added =
2377           ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2378       num_elements += num_added;
2379       max_num_cols = std::max(max_num_cols, num_added);
2380       for (int i = 0; i < num_added; i++) {
2381         if (is_batch) *out_indices++ = e;
2382         *out_indices++ = i;
2383       }
2384     }
2385     if (num_elements != expected_num_elements) {
2386       return errors::InvalidArgument(
2387           "Unexpected total number of elements in feature ", c.feature_name);
2388     }
2389     if (is_batch) {
2390       out_shape(0) = num_examples;
2391       out_shape(1) = max_num_cols;
2392     } else {
2393       out_shape(0) = max_num_cols;
2394     }
2395   }
2396   return Status::OK();
2397 }
2398 
// Parses the ragged context features listed in `context_config.ragged`. For
// each feature index t it writes the flat values to
// `context_result->ragged_values[t]` and the row splits to
// `context_result->ragged_splits[t]`.
//
// Values from all examples are concatenated in example order into a 1-D
// tensor of length `feature.length` (precomputed by the caller).
//
// When `is_batch` is true, the splits tensor has num_examples + 1 entries of
// type `c.splits_dtype`: splits[0] = 0, and splits[e + 1] = the number of
// values in examples 0..e. When `is_batch` is false, the splits tensor is
// allocated from a shape with no dims added and is not written here. Splits
// are also left unwritten if `splits_dtype` is neither DT_INT32 nor DT_INT64.
//
// Returns InvalidArgument if the total number of parsed values differs from
// `feature.length`, or if the number of splits written is not
// num_examples + 1. `example_names` is not used by this function.
Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.ragged.size(); ++t) {
    const auto& c = context_config.ragged[t];
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape values_shape, splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    size_t expected_num_elements = feature.length;
    values_shape.AddDim(expected_num_elements);
    if (is_batch) {
      splits_shape.AddDim(num_examples + 1);
    }
    context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, splits_shape);
    Tensor& out_values = context_result->ragged_values[t];
    size_t out_values_offset = 0;
    // Write cursors into the splits tensor. At most one is non-null: only
    // when `is_batch` is true and `splits_dtype` matches its element type.
    int32* int32_splits =
        is_batch && splits_dtype == DT_INT32
            ? context_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_splits =
        is_batch && splits_dtype == DT_INT64
            ? context_result->ragged_splits[t].vec<int64>().data()
            : nullptr;
    // The splits vector always begins with 0.
    if (int32_splits) {
      *int32_splits++ = 0;
    } else if (int64_splits) {
      *int64_splits++ = 0;
    }

    // Fill in the values.
    size_t split = 0;  // = total number of elements we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      // An empty serialized feature contributes no values, but the example
      // still gets a split entry below (equal to the previous one).
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        size_t num_added =
            ParseFeature(dtype, &stream, &out_values, &out_values_offset);
        split += num_added;
      }
      // NOTE(review): `split` is a size_t; storing it into an int32 split
      // silently truncates if the running total exceeds INT32_MAX.
      if (int32_splits) {
        *int32_splits++ = split;
      } else if (int64_splits) {
        *int64_splits++ = split;
      }
    }
    if (split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    // Sanity check: the cursor must have advanced exactly num_examples + 1
    // positions (the leading 0 plus one entry per example).
    if (int32_splits || int64_splits) {
      int actual_splits =
          int32_splits
              ? int32_splits -
                    context_result->ragged_splits[t].vec<int32>().data()
              : int64_splits -
                    context_result->ragged_splits[t].vec<int64>().data();
      if (actual_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return Status::OK();
}
2476 
2477 // Parses dense features in `sequence_features`, and writes their parsed
2478 // values to `sequence_result`.
ParseSequenceDenseFeatures(const FeatureProtosMap & sequence_features,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths)2479 Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
2480                                   const FastParseExampleConfig& sequence_config,
2481                                   gtl::ArraySlice<tstring> example_names,
2482                                   bool is_batch, int num_examples,
2483                                   Allocator* allocator, Result* sequence_result,
2484                                   std::vector<Tensor>* dense_feature_lengths) {
2485   TensorShape dense_length_shape;
2486   if (is_batch) {
2487     dense_length_shape.AddDim(num_examples);
2488   }
2489   for (int t = 0; t < sequence_config.dense.size(); ++t) {
2490     const auto& c = sequence_config.dense[t];
2491     const FeatureProtos& feature =
2492         sequence_features.find(c.feature_name)->second;
2493     TensorShape dense_shape, row_shape;
2494     DataType dtype = c.dtype;
2495     const size_t expected_max_elements = feature.length;
2496     if (!c.shape.AsTensorShape(&row_shape) ||
2497         expected_max_elements !=
2498             (expected_max_elements / row_shape.num_elements()) *
2499                 row_shape.num_elements()) {
2500       PartialTensorShape total_shape = row_shape;
2501       total_shape.InsertDim(0, -1);
2502       return errors::InvalidArgument(
2503           "Feature list '", c.feature_name,
2504           "' has an unexpected number of values.  Total values size: ",
2505           expected_max_elements,
2506           " is not consistent with output shape: ", total_shape.DebugString());
2507     }
2508     int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
2509     if (is_batch) {
2510       dense_shape.AddDim(num_examples);
2511     }
2512     dense_shape.AddDim(expected_max_rows);
2513     for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
2514       dense_shape.AddDim(dim);
2515     }
2516     sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2517     (*dense_feature_lengths)[t] =
2518         Tensor(allocator, DT_INT64, dense_length_shape);
2519     int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
2520 
2521     tstring* out_bytes = nullptr;
2522     float* out_float = nullptr;
2523     int64* out_int64 = nullptr;
2524     switch (dtype) {
2525       case DT_STRING:
2526         out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
2527         break;
2528       case DT_FLOAT:
2529         out_float = sequence_result->dense_values[t].flat<float>().data();
2530         break;
2531       case DT_INT64:
2532         out_int64 = sequence_result->dense_values[t].flat<int64>().data();
2533         break;
2534       default:
2535         ReportUnexpectedDataType(dtype);
2536     }
2537 
2538     // Fill in the values.
2539     for (int e = 0; e < num_examples; e++) {
2540       size_t num_elements = 0, num_rows = 0;
2541       const auto& feature_proto = feature.protos[e];
2542       if (!feature.protos_present[e]) {
2543         // Return an error if this feature was not allowed to be missing.
2544         // Otherwise, we'll pad as needed below.
2545         if (!c.variable_length) {
2546           return errors::InvalidArgument(
2547               "Name: ", ExampleName(example_names, e), ", Feature list '",
2548               c.feature_name,
2549               "' is required but could not be found.  "
2550               "Did you mean to include it in "
2551               "feature_list_dense_missing_assumed_empty or "
2552               "feature_list_dense_defaults?");
2553         }
2554       } else if (!feature_proto.empty()) {
2555         protobuf::io::CodedInputStream stream(
2556             reinterpret_cast<const uint8*>(feature_proto.data()),
2557             feature_proto.size());
2558         EnableAliasing(&stream);
2559         while (!stream.ExpectAtEnd()) {
2560           uint32 feature_length;
2561           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2562               !stream.ReadVarint32(&feature_length)) {
2563             return errors::InvalidArgument("Error in sequence feature ",
2564                                            c.feature_name, " in example ",
2565                                            ExampleName(example_names, e));
2566           }
2567           auto limit = stream.PushLimit(feature_length);
2568           int num_added = 0;
2569           if (feature_length > 2) {
2570             switch (dtype) {
2571               case DT_STRING:
2572                 num_added = ParseBytesFeature(&stream, out_bytes);
2573                 out_bytes += num_added;
2574                 break;
2575               case DT_FLOAT:
2576                 num_added = ParseFloatFeature(&stream, out_float);
2577                 out_float += num_added;
2578                 break;
2579               case DT_INT64:
2580                 num_added = ParseInt64Feature(&stream, out_int64);
2581                 out_int64 += num_added;
2582                 break;
2583               default:
2584                 ReportUnexpectedDataType(dtype);
2585                 num_added = 0;
2586             }
2587             if (num_added < 0) {
2588               // This should be unreachable -- we already scanned the feature in
2589               // GetSequenceFeatureLengths, and it hasn't changed since then.
2590               return errors::InvalidArgument("Error in sequence feature ",
2591                                              c.feature_name, " in example ",
2592                                              ExampleName(example_names, e));
2593             }
2594           }
2595           if (num_added != row_shape.num_elements()) {
2596             return errors::InvalidArgument(
2597                 "Name: ", ExampleName(example_names, e),
2598                 ", Key: ", c.feature_name, ", Index: ", num_rows,
2599                 ".  Number of values != expected.  values size: ", num_added,
2600                 " but output shape: ", row_shape.DebugString());
2601           }
2602           num_elements += num_added;
2603           num_rows++;
2604           stream.PopLimit(limit);
2605         }
2606       }
2607       *out_lengths++ = num_rows;
2608       // Pad as necessary.
2609       int num_to_pad = expected_max_elements - num_elements;
2610       switch (dtype) {
2611         case DT_STRING:
2612           out_bytes += num_to_pad;
2613           break;
2614         case DT_FLOAT:
2615           PadFloatFeature(num_to_pad, out_float);
2616           out_float += num_to_pad;
2617           break;
2618         case DT_INT64:
2619           PadInt64Feature(num_to_pad, out_int64);
2620           out_int64 += num_to_pad;
2621           break;
2622         default:
2623           ReportUnexpectedDataType(dtype);
2624       }
2625     }
2626   }
2627   return Status::OK();
2628 }
2629 
// Parses the sparse feature lists in `sequence_config.sparse` into
// `sequence_result->sparse_indices[t]`, `sparse_values[t]` and
// `sparse_shapes[t]`.
//
// Each example's serialized feature list is read as a sequence of
// length-delimited field-1 entries, each of which is one row:
//   - entries longer than 2 bytes are parsed and their values appended;
//   - 2-byte entries are consumed with SkipEmptyFeature;
//   - 0-byte entries contribute no values;
//   - any other length is an error.
// Every accepted entry counts as a row, even if it holds no values.
//
// Value i of row r in example e gets the index [e, r, i] when `is_batch`, and
// [r, i] otherwise. The dense shape is [num_examples, max_rows, max_cols] in
// batch mode and [max_rows, max_cols] otherwise.
//
// Returns InvalidArgument on malformed data, or if the total number of values
// differs from the precomputed `feature.length`.
Status ParseSequenceSparseFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.sparse.size(); ++t) {
    const auto& c = sequence_config.sparse[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements = feature.length;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 3 : 2);
    values_shape.AddDim(expected_num_elements);
    sequence_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));

    // Typed write cursor into the values tensor; only the one matching
    // `dtype` is set.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->sparse_values[t].flat<int64>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }
    int64* out_indices =
        sequence_result->sparse_indices[t].flat<int64>().data();
    auto out_shape = sequence_result->sparse_shapes[t].vec<int64>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_rows = 0;
    size_t max_num_cols = 0;
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_rows = 0;
      while (!stream.ExpectAtEnd()) {
        uint32 feature_length;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_length)) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        if (feature_length > 2) {
          auto limit = stream.PushLimit(feature_length);
          // NOTE(review): unlike the dense path, which checks
          // `num_added < 0`, this stores the Parse*Feature result in a size_t.
          // It relies on the earlier scan having validated the data.
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              ReportUnexpectedDataType(dtype);
              num_added = 0;
          }
          num_elements += num_added;
          max_num_cols = std::max(max_num_cols, num_added);
          // One index tuple per value: [e,] row, column.
          for (int i = 0; i < num_added; i++) {
            if (is_batch) *out_indices++ = e;
            *out_indices++ = num_rows;
            *out_indices++ = i;
          }
          stream.PopLimit(limit);
        } else if (feature_length == 2) {
          if (!SkipEmptyFeature(&stream, dtype)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
        } else if (feature_length != 0) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        // Empty rows (length 0 or 2) still advance the row index.
        num_rows++;
      }
      max_num_rows = std::max(max_num_rows, num_rows);
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature ", c.feature_name);
    }
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_rows;
      out_shape(2) = max_num_cols;
    } else {
      out_shape(0) = max_num_rows;
      out_shape(1) = max_num_cols;
    }
  }
  return Status::OK();
}
2757 
// Parses the ragged feature lists in `sequence_config.ragged`. For each
// feature index t it writes:
//   - `sequence_result->ragged_values[t]`: all values from all rows of all
//     examples, concatenated in order (length `feature.length`);
//   - `sequence_result->ragged_splits[t]` (inner splits): length
//     `feature.num_rows` + 1, starting at 0, where entry r + 1 is the number
//     of values in rows 0..r;
//   - `sequence_result->ragged_outer_splits[t]` (outer splits, written only
//     when `is_batch`): length num_examples + 1, starting at 0, where entry
//     e + 1 is the number of rows in examples 0..e.
// Splits are written only when `splits_dtype` is DT_INT32 or DT_INT64.
//
// Rows are the length-delimited field-1 entries of each example's feature
// list. Entries longer than 2 bytes are parsed, 2-byte entries are consumed
// with SkipEmptyFeature, 0-byte entries hold no values, and any other length
// is an error. Every accepted entry counts as one row.
//
// Returns InvalidArgument on malformed data, or if the row/value/example
// counts disagree with the precomputed `feature.num_rows`, `feature.length`
// or `num_examples`.
Status ParseSequenceRaggedFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.ragged.size(); ++t) {
    const auto& c = sequence_config.ragged[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape values_shape, inner_splits_shape, outer_splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    size_t expected_num_elements = feature.length;
    size_t expected_num_rows = feature.num_rows;
    values_shape.AddDim(expected_num_elements);
    inner_splits_shape.AddDim(expected_num_rows + 1);
    if (is_batch) {
      outer_splits_shape.AddDim(num_examples + 1);
    }
    sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, inner_splits_shape);
    sequence_result->ragged_outer_splits[t] =
        Tensor(allocator, splits_dtype, outer_splits_shape);
    Tensor& out_values = sequence_result->ragged_values[t];
    size_t out_values_offset = 0;
    // Write cursors into the splits tensors. At most one inner and at most
    // one outer cursor is non-null, selected by `splits_dtype`. The outer
    // cursors are additionally gated on `is_batch`.
    int32* int32_inner_splits =
        splits_dtype == DT_INT32
            ? sequence_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_inner_splits =
        splits_dtype == DT_INT64
            ? sequence_result->ragged_splits[t].vec<int64>().data()
            : nullptr;
    int32* int32_outer_splits =
        is_batch && splits_dtype == DT_INT32
            ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_outer_splits =
        is_batch && splits_dtype == DT_INT64
            ? sequence_result->ragged_outer_splits[t].vec<int64>().data()
            : nullptr;
    // Both splits vectors begin with 0.
    if (int32_inner_splits) {
      *int32_inner_splits++ = 0;
    } else if (int64_inner_splits) {
      *int64_inner_splits++ = 0;
    }
    if (int32_outer_splits) {
      *int32_outer_splits++ = 0;
    } else if (int64_outer_splits) {
      *int64_outer_splits++ = 0;
    }

    // Fill in the values.
    size_t inner_split = 0;  // total number of elements we've seen so far
    size_t outer_split = 0;  // total number of rows we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          if (feature_length > 2) {
            auto limit = stream.PushLimit(feature_length);
            size_t num_added =
                ParseFeature(dtype, &stream, &out_values, &out_values_offset);
            inner_split += num_added;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            if (!SkipEmptyFeature(&stream, dtype)) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          } else if (feature_length != 0) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          // Close the current row: one inner split per row, including empty
          // rows.
          if (int32_inner_splits) {
            *int32_inner_splits++ = inner_split;
          } else if (int64_inner_splits) {
            *int64_inner_splits++ = inner_split;
          }
          outer_split++;
        }
      }
      // Close the current example: one outer split per example, including
      // examples with an empty serialized feature.
      if (int32_outer_splits) {
        *int32_outer_splits++ = outer_split;
      } else if (int64_outer_splits) {
        *int64_outer_splits++ = outer_split;
      }
    }
    if (outer_split != expected_num_rows) {
      return errors::InvalidArgument("Unexpected number of rows for feature ",
                                     c.feature_name);
    }
    if (inner_split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements for feature ", c.feature_name);
    }

    // Sanity checks: each cursor must have advanced exactly as many positions
    // as its tensor has entries.
    if (int32_inner_splits || int64_inner_splits) {
      const auto& inner_splits = sequence_result->ragged_splits[t];
      int num_inner_splits =
          int32_inner_splits
              ? int32_inner_splits - inner_splits.vec<int32>().data()
              : int64_inner_splits - inner_splits.vec<int64>().data();
      if (num_inner_splits != expected_num_rows + 1) {
        return errors::InvalidArgument("Unexpected number of rows for feature ",
                                       c.feature_name);
      }
    }
    if (int32_outer_splits || int64_outer_splits) {
      const auto& outer_splits = sequence_result->ragged_outer_splits[t];
      int num_outer_splits =
          int32_outer_splits
              ? int32_outer_splits - outer_splits.vec<int32>().data()
              : int64_outer_splits - outer_splits.vec<int64>().data();
      if (num_outer_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return Status::OK();
}
2902 
2903 }  // namespace
2904 
2905 // TODO(sundberg): Use the threadpool to parallelize example parsing.
2906 // TODO(b/111553342): Support extracting feature statistics from the examples.
FastParseSequenceExample(const FastParseExampleConfig & context_config,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * context_result,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths,bool is_batch)2907 Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
2908                                 const FastParseExampleConfig& sequence_config,
2909                                 gtl::ArraySlice<tstring> serialized,
2910                                 gtl::ArraySlice<tstring> example_names,
2911                                 thread::ThreadPool* thread_pool,
2912                                 Result* context_result, Result* sequence_result,
2913                                 std::vector<Tensor>* dense_feature_lengths,
2914                                 bool is_batch) {
2915   int num_examples = serialized.size();
2916   DCHECK(context_result != nullptr);
2917   DCHECK(sequence_result != nullptr);
2918   DCHECK(dense_feature_lengths != nullptr);
2919   size_t num_context_features = context_config.sparse.size() +
2920                                 context_config.dense.size() +
2921                                 context_config.ragged.size();
2922   FeatureProtosMap context_features;
2923   context_features.reserve(num_context_features);
2924 
2925   if (!example_names.empty() && example_names.size() != num_examples) {
2926     return errors::InvalidArgument(
2927         "example_names must be empty or have the correct number of elements");
2928   }
2929   for (auto& c : context_config.sparse) {
2930     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2931     FeatureProtos& feature = context_features[c.feature_name];
2932     feature.dtype = c.dtype;
2933     feature.length = 0;
2934     feature.type = Type::Sparse;
2935     feature.protos.resize(num_examples);
2936     feature.protos_present.resize(num_examples);
2937   }
2938   for (auto& c : context_config.ragged) {
2939     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2940     FeatureProtos& feature = context_features[c.feature_name];
2941     if (feature.type == Type::Sparse) {
2942       return errors::InvalidArgument("Context feature " + c.feature_name +
2943                                      " cannot be both ragged and sparse");
2944     }
2945     feature.dtype = c.dtype;
2946     feature.length = 0;
2947     feature.type = Type::Ragged;
2948     feature.protos.resize(num_examples);
2949     feature.protos_present.resize(num_examples);
2950   }
2951   for (auto& c : context_config.dense) {
2952     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2953     FeatureProtos& feature = context_features[c.feature_name];
2954     if (feature.type != Type::Dense) {
2955       return errors::InvalidArgument("Context feature " + c.feature_name +
2956                                      " cannot be both dense and sparse");
2957     }
2958     if (c.default_value.NumElements() > 0) {
2959       if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
2960         return errors::InvalidArgument("Default value for context feature ",
2961                                        c.feature_name,
2962                                        " has an incorrect shape: saw ",
2963                                        c.default_value.shape().DebugString(),
2964                                        " but expected ", c.shape.DebugString());
2965       }
2966     }
2967     feature.dtype = c.dtype;
2968     feature.length = c.default_value.NumElements();
2969     feature.protos.resize(num_examples);
2970     feature.protos_present.resize(num_examples);
2971   }
2972   size_t num_sequence_features = sequence_config.sparse.size() +
2973                                  sequence_config.dense.size() +
2974                                  sequence_config.ragged.size();
2975   FeatureProtosMap sequence_features;
2976   sequence_features.reserve(num_sequence_features);
2977   for (auto& c : sequence_config.sparse) {
2978     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2979     FeatureProtos& feature = sequence_features[c.feature_name];
2980     feature.dtype = c.dtype;
2981     feature.length = 0;
2982     feature.type = Type::Sparse;
2983     feature.protos.resize(num_examples);
2984     feature.protos_present.resize(num_examples);
2985   }
2986   for (auto& c : sequence_config.ragged) {
2987     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2988     FeatureProtos& feature = sequence_features[c.feature_name];
2989     if (feature.type == Type::Sparse) {
2990       return errors::InvalidArgument("Sequence feature " + c.feature_name +
2991                                      " cannot be both ragged and sparse");
2992     }
2993     feature.dtype = c.dtype;
2994     feature.length = 0;
2995     feature.type = Type::Ragged;
2996     feature.protos.resize(num_examples);
2997     feature.protos_present.resize(num_examples);
2998   }
2999   for (auto& c : sequence_config.dense) {
3000     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3001     FeatureProtos& feature = sequence_features[c.feature_name];
3002     if (feature.type != Type::Dense) {
3003       return errors::InvalidArgument("Sequence feature " + c.feature_name +
3004                                      " cannot be both dense and sparse");
3005     }
3006     feature.dtype = c.dtype;
3007     feature.length = 0;
3008     feature.protos.resize(num_examples);
3009     feature.protos_present.resize(num_examples);
3010   }
3011 
3012   // Find the serialized proto substrings for each feature.
3013   TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
3014       serialized, example_names, &context_features, &sequence_features));
3015 
3016   // Scan through the protos to determine how much memory we need to allocate.
3017   TF_RETURN_IF_ERROR(
3018       GetContextFeatureLengths(example_names, &context_features));
3019   TF_RETURN_IF_ERROR(
3020       GetSequenceFeatureLengths(example_names, &sequence_features));
3021 
3022   // Allocate memory.
3023   context_result->sparse_values.resize(context_config.sparse.size());
3024   context_result->sparse_indices.resize(context_config.sparse.size());
3025   context_result->sparse_shapes.resize(context_config.sparse.size());
3026   context_result->dense_values.resize(context_config.dense.size());
3027   context_result->ragged_values.resize(context_config.ragged.size());
3028   context_result->ragged_splits.resize(context_config.ragged.size());
3029   context_result->ragged_outer_splits.resize(context_config.ragged.size());
3030   sequence_result->sparse_values.resize(sequence_config.sparse.size());
3031   sequence_result->sparse_indices.resize(sequence_config.sparse.size());
3032   sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
3033   sequence_result->dense_values.resize(sequence_config.dense.size());
3034   sequence_result->ragged_values.resize(sequence_config.ragged.size());
3035   sequence_result->ragged_splits.resize(sequence_config.ragged.size());
3036   sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
3037   dense_feature_lengths->resize(sequence_config.dense.size());
3038 
3039   // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
3040   // to avoid lock contention in `tensorflow::cpu_allocator()`.
3041   Allocator* allocator = tensorflow::cpu_allocator();
3042 
3043   TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
3044       context_features, context_config, example_names, is_batch, num_examples,
3045       allocator, context_result));
3046   TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
3047       context_features, context_config, example_names, is_batch, num_examples,
3048       allocator, context_result));
3049   TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
3050       context_features, context_config, example_names, is_batch, num_examples,
3051       allocator, context_result));
3052   TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
3053       sequence_features, sequence_config, example_names, is_batch, num_examples,
3054       allocator, sequence_result, dense_feature_lengths));
3055   TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
3056       sequence_features, sequence_config, example_names, is_batch, num_examples,
3057       allocator, sequence_result));
3058   TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
3059       sequence_features, sequence_config, example_names, is_batch, num_examples,
3060       allocator, sequence_result));
3061 
3062   return Status::OK();
3063 }
3064 
3065 }  // namespace example
3066 }  // namespace tensorflow
3067