1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16
17 #include <vector>
18
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/numeric_op.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/blocking_counter.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/monitoring/counter.h"
33 #include "tensorflow/core/platform/byte_order.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/util/presized_cuckoo_map.h"
37 #include "tensorflow/core/util/sparse/sparse_tensor.h"
38
39 namespace tensorflow {
40 namespace example {
41
42 namespace {
43
44 template <typename T>
45 using SmallVector = gtl::InlinedVector<T, 4>;
46
47 template <typename T>
48 class LimitedArraySlice {
49 public:
50 using value_type = T;
51
LimitedArraySlice(T * begin,size_t num_elements)52 LimitedArraySlice(T* begin, size_t num_elements)
53 : current_(begin), begin_(begin), end_(begin + num_elements) {}
54
55 // May return negative if there were push_back calls after slice was filled.
EndDistance() const56 int64 EndDistance() const { return end_ - current_; }
57
58 // Attempts to push value to the back of this. If the slice has
59 // already been filled, this method has no effect on the underlying data, but
60 // it changes the number returned by EndDistance into negative values.
push_back(T && value)61 void push_back(T&& value) {
62 if (EndDistance() > 0) *current_ = std::move(value);
63 ++current_;
64 }
65
66 // "Constructs" an element at the back of this by resizing the slice, and
67 // returns a mutable reference to the new last element.
68 // REQUIRES: EndDistance() > 0.
construct_at_end()69 T& construct_at_end() {
70 DCHECK_GT(EndDistance(), 0);
71 return *(current_++);
72 }
73
74 // Returns a mutable reference to the last element in the slice.
75 // REQUIRES: size() > 0.
back()76 T& back() { return *(current_ - 1); }
77
78 // Returns the number of elements in the slice.
size() const79 size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
80
81 // Attempts to resize the vector to the given size. It does so by advancing
82 // the pointer to the current element, possibly beyond the end of the slice.
83 // As a consequence, calling `size()` after `resize(x)` was called might
84 // return a value less than `x`.
resize(size_t size)85 void resize(size_t size) { current_ = begin_ + size; }
86
87 // Returns the pointer to the underlying data buffer.
data()88 T* data() { return begin_; }
89
90 private:
91 T* current_;
92 T* begin_;
93 T* end_;
94 };
95
// Calls `stream->EnableAliasing(true)` for stream types that provide such a
// member. The trailing return type makes this overload drop out of overload
// resolution when that expression is ill-formed.
template <typename Stream>
auto EnableAliasing(Stream* stream)
    -> decltype(stream->EnableAliasing(true), void()) {
  stream->EnableAliasing(true);
}

// Fallback overload: accepts anything and does nothing.
template <typename Stream>
void EnableAliasing(Stream&& /*stream*/) {}
103
PeekTag(protobuf::io::CodedInputStream * stream)104 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
105 DCHECK(stream != nullptr);
106 const void* ptr;
107 int size;
108 if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
109 return *static_cast<const uint8*>(ptr);
110 }
111
kVarintTag(uint32 tag)112 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)113 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)114 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
115
116 namespace parsed {
117
118 // ParseDataType has to be called first, then appropriate ParseZzzzList.
119 class Feature {
120 public:
Feature()121 Feature() {}
Feature(StringPiece serialized)122 explicit Feature(StringPiece serialized) : serialized_(serialized) {}
123
ParseDataType(DataType * dtype)124 Status ParseDataType(DataType* dtype) {
125 DCHECK(dtype != nullptr);
126 if (serialized_.empty()) {
127 *dtype = DT_INVALID;
128 return Status::OK();
129 }
130 uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
131 serialized_.remove_prefix(1);
132 switch (oneof_tag) {
133 case kDelimitedTag(1):
134 *dtype = DT_STRING;
135 break;
136 case kDelimitedTag(2):
137 *dtype = DT_FLOAT;
138 break;
139 case kDelimitedTag(3):
140 *dtype = DT_INT64;
141 break;
142 default:
143 // Initialize variable to avoid compiler warning
144 *dtype = DT_INVALID;
145 return errors::InvalidArgument("Unsupported datatype.");
146 }
147 return Status::OK();
148 }
149
GetNumElementsInBytesList(int * num_elements)150 bool GetNumElementsInBytesList(int* num_elements) {
151 protobuf::io::CodedInputStream stream(
152 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
153 EnableAliasing(&stream);
154 uint32 length = 0;
155 if (!stream.ReadVarint32(&length)) return false;
156 auto limit = stream.PushLimit(length);
157 *num_elements = 0;
158 while (!stream.ExpectAtEnd()) {
159 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
160 uint32 bytes_length = 0;
161 if (!stream.ReadVarint32(&bytes_length)) return false;
162 if (!stream.Skip(bytes_length)) return false;
163 ++*num_elements;
164 }
165 stream.PopLimit(limit);
166 return true;
167 }
168
169 // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)170 tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
171 if (bytes_list->EndDistance() <= 0) {
172 return nullptr;
173 }
174 return &bytes_list->construct_at_end();
175 }
construct_at_end(SmallVector<tstring> * bytes_list)176 tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
177 return &bytes_list->emplace_back();
178 }
179
180 template <typename Result>
ParseBytesList(Result * bytes_list)181 bool ParseBytesList(Result* bytes_list) {
182 DCHECK(bytes_list != nullptr);
183
184 protobuf::io::CodedInputStream stream(
185 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
186
187 EnableAliasing(&stream);
188
189 uint32 length;
190 if (!stream.ReadVarint32(&length)) return false;
191 auto limit = stream.PushLimit(length);
192
193 while (!stream.ExpectAtEnd()) {
194 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
195 // parse string
196 uint32 bytes_length;
197 if (!stream.ReadVarint32(&bytes_length)) return false;
198 tstring* bytes = construct_at_end(bytes_list);
199 if (bytes == nullptr) return false;
200 bytes->resize_uninitialized(bytes_length);
201 if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
202 }
203 stream.PopLimit(limit);
204 return true;
205 }
206
207 template <typename Result>
ParseFloatList(Result * float_list)208 bool ParseFloatList(Result* float_list) {
209 DCHECK(float_list != nullptr);
210 protobuf::io::CodedInputStream stream(
211 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
212 EnableAliasing(&stream);
213 uint32 length;
214 if (!stream.ReadVarint32(&length)) return false;
215 auto limit = stream.PushLimit(length);
216
217 if (!stream.ExpectAtEnd()) {
218 uint8 peek_tag = PeekTag(&stream);
219 if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
220 return false;
221 }
222
223 constexpr int32 kNumFloatBytes = 4;
224 if (peek_tag == kDelimitedTag(1)) { // packed
225 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
226 uint32 packed_length;
227 if (!stream.ReadVarint32(&packed_length)) return false;
228 auto packed_limit = stream.PushLimit(packed_length);
229
230 // Store the initial size to know the offset we have to start writing
231 // data from before resizing the output "vector".
232 const size_t initial_size = float_list->size();
233 float_list->resize(initial_size + packed_length / kNumFloatBytes);
234
235 // If the result data type is float and we are on a little endian
236 // machine then we can simply memcpy the data from the proto into the
237 // result vector.
238 if (port::kLittleEndian &&
239 sizeof(typename Result::value_type) == kNumFloatBytes) {
240 // Calculate the length of the buffer available what can be less than
241 // what we requested in resize in case of a LimitedArraySlice.
242 const uint32 bytes_to_copy =
243 std::min(static_cast<uint32>((float_list->size() - initial_size) *
244 kNumFloatBytes),
245 packed_length);
246 if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
247 return false;
248 } else {
249 int64 index = initial_size;
250 while (!stream.ExpectAtEnd()) {
251 uint32 buffer32;
252 if (!stream.ReadLittleEndian32(&buffer32)) return false;
253 if (index < float_list->size()) {
254 float_list->data()[index] = absl::bit_cast<float>(buffer32);
255 ++index;
256 }
257 }
258 }
259
260 stream.PopLimit(packed_limit);
261 } else { // non-packed
262 const size_t initial_size = float_list->size();
263 // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
264 // the value.
265 const int64 num_elements =
266 stream.BytesUntilLimit() / (1 + kNumFloatBytes);
267 float_list->resize(initial_size + num_elements);
268 int64 index = initial_size;
269 while (!stream.ExpectAtEnd()) {
270 if (!stream.ExpectTag(kFixed32Tag(1))) return false;
271 uint32 buffer32;
272 if (!stream.ReadLittleEndian32(&buffer32)) return false;
273 float_list->data()[index] = absl::bit_cast<float>(buffer32);
274 ++index;
275 }
276 }
277 }
278
279 stream.PopLimit(limit);
280 return true;
281 }
282
283 template <typename Result>
ParseInt64List(Result * int64_list)284 bool ParseInt64List(Result* int64_list) {
285 DCHECK(int64_list != nullptr);
286 protobuf::io::CodedInputStream stream(
287 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
288 EnableAliasing(&stream);
289 uint32 length;
290 if (!stream.ReadVarint32(&length)) return false;
291 auto limit = stream.PushLimit(length);
292
293 if (!stream.ExpectAtEnd()) {
294 uint8 peek_tag = PeekTag(&stream);
295 if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
296 return false;
297 }
298 if (peek_tag == kDelimitedTag(1)) { // packed
299 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
300 uint32 packed_length;
301 if (!stream.ReadVarint32(&packed_length)) return false;
302 auto packed_limit = stream.PushLimit(packed_length);
303
304 while (!stream.ExpectAtEnd()) {
305 protobuf_uint64 n; // There is no API for int64
306 if (!stream.ReadVarint64(&n)) return false;
307 int64_list->push_back(static_cast<int64>(n));
308 }
309
310 stream.PopLimit(packed_limit);
311 } else { // non-packed
312 while (!stream.ExpectAtEnd()) {
313 if (!stream.ExpectTag(kVarintTag(1))) return false;
314 protobuf_uint64 n; // There is no API for int64
315 if (!stream.ReadVarint64(&n)) return false;
316 int64_list->push_back(static_cast<int64>(n));
317 }
318 }
319 }
320 stream.PopLimit(limit);
321 return true;
322 }
323
GetSerialized() const324 StringPiece GetSerialized() const { return serialized_; }
325
326 private:
327 // TODO(lew): Pair of uint8* would be more natural.
328 StringPiece serialized_;
329 };
330
331 using FeatureMapEntry = std::pair<StringPiece, Feature>;
332 using Example = std::vector<FeatureMapEntry>;
333
334 } // namespace parsed
335
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)336 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
337 uint32 data;
338 protobuf_uint64 dummy;
339 switch (stream->ReadTag() & 0x7) {
340 case 0: // varint
341 if (!stream->ReadVarint32(&data)) return false;
342 return true;
343 case 1: // fixed64
344 if (!stream->ReadLittleEndian64(&dummy)) return false;
345 return true;
346 case 2: // length delimited
347 if (!stream->ReadVarint32(&data)) return false;
348 stream->Skip(data);
349 return true;
350 case 3: // group begin
351 return false; // groups not supported.
352 case 4: // group end
353 return false; // groups not supported.
354 case 5: // fixed32
355 if (!stream->ReadLittleEndian32(&data)) return false;
356 return true;
357 }
358 return false; // unrecognized tag type
359 }
360
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)361 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
362 DCHECK(stream != nullptr);
363 DCHECK(result != nullptr);
364 uint32 length;
365 if (!stream->ReadVarint32(&length)) return false;
366 if (length == 0) {
367 *result = StringPiece(nullptr, 0);
368 return true;
369 }
370 const void* stream_alias;
371 int stream_size;
372 if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
373 return false;
374 }
375 if (static_cast<uint32>(stream_size) < length) return false;
376 *result = StringPiece(static_cast<const char*>(stream_alias), length);
377 stream->Skip(length);
378 return true;
379 }
380
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)381 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
382 parsed::FeatureMapEntry* feature_map_entry) {
383 DCHECK(stream != nullptr);
384 DCHECK(feature_map_entry != nullptr);
385 uint32 length;
386 if (!stream->ReadVarint32(&length)) return false;
387 auto limit = stream->PushLimit(length);
388 if (!stream->ExpectTag(kDelimitedTag(1))) return false;
389 if (!ParseString(stream, &feature_map_entry->first)) return false;
390 if (!stream->ExpectTag(kDelimitedTag(2))) return false;
391 StringPiece feature_string_piece;
392 if (!ParseString(stream, &feature_string_piece)) return false;
393 feature_map_entry->second = parsed::Feature(feature_string_piece);
394 if (!stream->ExpectAtEnd()) return false;
395 stream->PopLimit(limit);
396 return true;
397 }
398
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)399 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
400 parsed::Example* example) {
401 DCHECK(stream != nullptr);
402 DCHECK(example != nullptr);
403 uint32 length;
404 if (!stream->ReadVarint32(&length)) return false;
405 auto limit = stream->PushLimit(length);
406 while (!stream->ExpectAtEnd()) {
407 parsed::FeatureMapEntry feature_map_entry;
408 if (!stream->ExpectTag(kDelimitedTag(1))) return false;
409 if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
410 example->push_back(std::move(feature_map_entry));
411 }
412 stream->PopLimit(limit);
413 return true;
414 }
415
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)416 bool ParseExample(protobuf::io::CodedInputStream* stream,
417 parsed::Example* example) {
418 DCHECK(stream != nullptr);
419 DCHECK(example != nullptr);
420 // Loop over the input stream which may contain multiple serialized Example
421 // protos merged together as strings. This behavior is consistent with Proto's
422 // ParseFromString when string representations are concatenated.
423 while (!stream->ExpectAtEnd()) {
424 if (!stream->ExpectTag(kDelimitedTag(1))) {
425 if (!SkipExtraneousTag(stream)) return false;
426 } else {
427 if (!ParseFeatures(stream, example)) return false;
428 }
429 }
430 return true;
431 }
432
ParseExample(StringPiece serialized,parsed::Example * example)433 bool ParseExample(StringPiece serialized, parsed::Example* example) {
434 DCHECK(example != nullptr);
435 protobuf::io::CodedInputStream stream(
436 reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
437 EnableAliasing(&stream);
438 return ParseExample(&stream, example);
439 }
440
441 } // namespace
442
TestFastParse(const string & serialized,Example * example)443 bool TestFastParse(const string& serialized, Example* example) {
444 DCHECK(example != nullptr);
445 parsed::Example parsed_example;
446 if (!ParseExample(serialized, &parsed_example)) return false;
447 auto& features = *example->mutable_features();
448 size_t parsed_example_size = parsed_example.size();
449 for (size_t i = 0; i < parsed_example_size; ++i) {
450 // This is a logic that standard protobuf parsing is implementing.
451 // I.e. last entry in the map overwrites all the previous ones.
452 parsed::FeatureMapEntry& name_and_feature =
453 parsed_example[parsed_example_size - i - 1];
454 string name(name_and_feature.first);
455 if ((*features.mutable_feature()).count(name) > 0) continue;
456
457 auto& value = (*features.mutable_feature())[name];
458 DataType dtype;
459 if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
460 switch (dtype) {
461 case DT_INVALID:
462 break;
463 case DT_STRING: {
464 SmallVector<tstring> list;
465 if (!name_and_feature.second.ParseBytesList(&list)) return false;
466 auto* result_list = value.mutable_bytes_list();
467 for (auto& bytes : list) {
468 result_list->add_value(bytes.data(), bytes.size());
469 }
470 break;
471 }
472 case DT_FLOAT: {
473 SmallVector<float> list;
474 if (!name_and_feature.second.ParseFloatList(&list)) return false;
475 auto* result_list = value.mutable_float_list();
476 for (float f : list) {
477 result_list->add_value(f);
478 }
479 break;
480 }
481 case DT_INT64: {
482 SmallVector<int64> list;
483 if (!name_and_feature.second.ParseInt64List(&list)) return false;
484 auto* result_list = value.mutable_int64_list();
485 for (int64 i : list) {
486 result_list->add_value(i);
487 }
488 break;
489 }
490 default:
491 LOG(FATAL) << "Should not happen.";
492 }
493 }
494 return true;
495 }
496
497 // -----------------------------------------------------------------------------
498
499 namespace {
500
501 using Config = FastParseExampleConfig;
502
ParallelFor(const std::function<void (size_t)> & f,size_t n,thread::ThreadPool * thread_pool)503 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
504 thread::ThreadPool* thread_pool) {
505 if (n == 0) return;
506 if (thread_pool == nullptr) {
507 for (size_t i = 0; i < n; ++i) {
508 f(i);
509 }
510 } else {
511 BlockingCounter counter(n - 1);
512 for (size_t i = 1; i < n; ++i) {
513 thread_pool->Schedule([i, &f, &counter] {
514 f(i);
515 counter.DecrementCount();
516 });
517 }
518 f(0);
519 counter.Wait();
520 }
521 }
522
// Distinguishes the kinds of features the parser handles.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense, so Dense
// must remain the first enumerator.
enum class Type {
  Dense,
  Sparse,
  Ragged,
};
527
528 // Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
529 struct SparseBuffer {
530 // Features are in one of the 3 vectors below depending on config's dtype.
531 // Other 2 vectors remain empty.
532 SmallVector<tstring> bytes_list;
533 SmallVector<float> float_list;
534 SmallVector<int64> int64_list;
535
536 // Features of example i are elements with indices
537 // from example_end_indices[i-1] to example_end_indices[i]-1 on the
538 // appropriate xxxxx_list
539 std::vector<size_t> example_end_indices;
540 };
541
542 struct SeededHasher {
operator ()tensorflow::example::__anon4c07364d0211::SeededHasher543 uint64 operator()(StringPiece s) const {
544 return Hash64(s.data(), s.size(), seed);
545 }
546 uint64 seed{0xDECAFCAFFE};
547 };
548
LogDenseFeatureDataLoss(StringPiece feature_name)549 void LogDenseFeatureDataLoss(StringPiece feature_name) {
550 LOG(WARNING) << "Data loss! Feature '" << feature_name
551 << "' is present in multiple concatenated "
552 "tf.Examples. Ignoring all but last one.";
553 static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
554 "/tensorflow/core/util/example_proto_fast_parsing/"
555 "duplicated_dense_feature",
556 "Dense feature appears twice in a tf.Example");
557 duplicated_dense_feature->GetCell()->IncrementBy(1);
558 }
559
LogSparseFeatureDataLoss(StringPiece feature_name)560 void LogSparseFeatureDataLoss(StringPiece feature_name) {
561 LOG(WARNING) << "Data loss! Feature '" << feature_name
562 << "' is present in multiple concatenated "
563 "tf.Examples. Ignoring all but last one.";
564 static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
565 "/tensorflow/core/util/example_proto_fast_parsing/"
566 "duplicated_sparse_feature",
567 "Sparse feature appears twice in a tf.Example");
568 duplicated_sparse_feature->GetCell()->IncrementBy(1);
569 }
570
FastParseSerializedExample(const tstring & serialized_example,const tstring & example_name,const size_t example_index,const Config & config,const PresizedCuckooMap<std::pair<size_t,Type>> & config_index,SeededHasher hasher,std::vector<Tensor> * output_dense,std::vector<SparseBuffer> * output_varlen_dense,std::vector<SparseBuffer> * output_sparse,std::vector<SparseBuffer> * output_ragged,PerExampleFeatureStats * output_stats)571 Status FastParseSerializedExample(
572 const tstring& serialized_example, const tstring& example_name,
573 const size_t example_index, const Config& config,
574 const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
575 SeededHasher hasher, std::vector<Tensor>* output_dense,
576 std::vector<SparseBuffer>* output_varlen_dense,
577 std::vector<SparseBuffer>* output_sparse,
578 std::vector<SparseBuffer>* output_ragged,
579 PerExampleFeatureStats* output_stats) {
580 DCHECK(output_dense != nullptr);
581 DCHECK(output_sparse != nullptr);
582 DCHECK(output_ragged != nullptr);
583 parsed::Example parsed_example;
584 if (!ParseExample(serialized_example, &parsed_example)) {
585 return errors::InvalidArgument("Could not parse example input, value: '",
586 serialized_example, "'");
587 }
588 std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
589 std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
590 std::vector<int64> ragged_feature_last_example(config.ragged.size(), -1);
591
592 // Handle features present in the example.
593 const size_t parsed_example_size = parsed_example.size();
594
595 if (output_stats) {
596 // TODO(b/111553342): This may over-count the number of features if there
597 // are duplicate keys in the feature map. Consider deduplicating the keys
598 // before computing the count.
599 output_stats->features_count = parsed_example_size;
600 }
601
602 for (size_t i = 0; i < parsed_example_size; ++i) {
603 // This is a logic that standard protobuf parsing is implementing.
604 // I.e. last entry in the map overwrites all the previous ones.
605 parsed::FeatureMapEntry& name_and_feature =
606 parsed_example[parsed_example_size - i - 1];
607
608 const StringPiece feature_name = name_and_feature.first;
609 parsed::Feature& feature = name_and_feature.second;
610
611 std::pair<size_t, Type> d_and_type;
612 uint64 h = hasher(feature_name);
613 if (!config_index.Find(h, &d_and_type)) continue;
614
615 size_t d = d_and_type.first;
616 bool is_dense = d_and_type.second == Type::Dense;
617 bool is_ragged = d_and_type.second == Type::Ragged;
618
619 {
620 // Testing for PresizedCuckooMap collision.
621 // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
622 const tstring& config_feature_name =
623 is_dense ? config.dense[d].feature_name
624 : (is_ragged ? config.ragged[d].feature_name
625 : config.sparse[d].feature_name);
626 if (feature_name != config_feature_name) continue;
627 }
628
629 auto example_error = [&](StringPiece suffix) {
630 return errors::InvalidArgument("Name: ", example_name,
631 ", Key: ", feature_name,
632 ", Index: ", example_index, ". ", suffix);
633 };
634
635 auto parse_error = [&] {
636 return example_error("Can't parse serialized Example.");
637 };
638
639 DataType example_dtype;
640 TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
641
642 if (is_dense) {
643 if (example_dtype == DT_INVALID) continue;
644
645 // If feature was already visited, skip.
646 // Compare comment at the beginning of the loop.
647 if (dense_feature_last_example[d] == example_index) {
648 LogDenseFeatureDataLoss(feature_name);
649 continue;
650 }
651 dense_feature_last_example[d] = example_index;
652
653 if (example_dtype != config.dense[d].dtype) {
654 return example_error(strings::StrCat(
655 "Data types don't match. Data type: ",
656 DataTypeString(example_dtype),
657 " but expected type: ", DataTypeString(config.dense[d].dtype)));
658 }
659 if (!config.dense[d].variable_length) {
660 Tensor& out = (*output_dense)[d];
661
662 const std::size_t num_elements = config.dense[d].elements_per_stride;
663 if (output_stats) {
664 // TODO(b/111553342): If desirable, we could add support for counting
665 // elements in the features that aren't parsed, but this could add
666 // considerable runtime cost.
667 output_stats->feature_values_count += num_elements;
668 }
669
670 const std::size_t offset = example_index * num_elements;
671
672 auto shape_error = [&](size_t size, StringPiece type_str) {
673 return example_error(strings::StrCat(
674 "Number of ", type_str,
675 " values != expected. "
676 "Values size: ",
677 size,
678 " but output shape: ", config.dense[d].shape.DebugString()));
679 };
680
681 switch (config.dense[d].dtype) {
682 case DT_INT64: {
683 auto out_p = out.flat<int64>().data() + offset;
684 LimitedArraySlice<int64> slice(out_p, num_elements);
685 if (!feature.ParseInt64List(&slice)) return parse_error();
686 if (slice.EndDistance() != 0) {
687 return shape_error(num_elements - slice.EndDistance(), "int64");
688 }
689 break;
690 }
691 case DT_FLOAT: {
692 auto out_p = out.flat<float>().data() + offset;
693 LimitedArraySlice<float> slice(out_p, num_elements);
694 if (!feature.ParseFloatList(&slice)) return parse_error();
695 if (slice.EndDistance() != 0) {
696 return shape_error(num_elements - slice.EndDistance(), "float");
697 }
698 break;
699 }
700 case DT_STRING: {
701 auto out_p = out.flat<tstring>().data() + offset;
702 LimitedArraySlice<tstring> slice(out_p, num_elements);
703 if (!feature.ParseBytesList(&slice)) return parse_error();
704 if (slice.EndDistance() != 0) {
705 return shape_error(num_elements - slice.EndDistance(), "bytes");
706 }
707 break;
708 }
709 default:
710 LOG(FATAL) << "Should not happen.";
711 }
712 } else { // if variable length
713 SparseBuffer& out = (*output_varlen_dense)[d];
714
715 const std::size_t num_elements = config.dense[d].elements_per_stride;
716
717 if (example_dtype != DT_INVALID &&
718 example_dtype != config.dense[d].dtype) {
719 return example_error(strings::StrCat(
720 "Data types don't match. ",
721 "Expected type: ", DataTypeString(config.dense[d].dtype)));
722 }
723
724 auto shape_error = [&](size_t size, StringPiece type_str) {
725 return example_error(strings::StrCat(
726 "Number of ", type_str,
727 " values is not a multiple of stride length. Saw ", size,
728 " values but output shape is: ",
729 config.dense[d].shape.DebugString()));
730 };
731
732 switch (config.dense[d].dtype) {
733 case DT_INT64: {
734 if (example_dtype != DT_INVALID) {
735 if (!feature.ParseInt64List(&out.int64_list)) {
736 return parse_error();
737 }
738 if (out.int64_list.size() % num_elements != 0) {
739 return shape_error(out.int64_list.size(), "int64");
740 }
741 }
742 out.example_end_indices.push_back(out.int64_list.size());
743 break;
744 }
745 case DT_FLOAT: {
746 if (example_dtype != DT_INVALID) {
747 if (!feature.ParseFloatList(&out.float_list)) {
748 return parse_error();
749 }
750 if (out.float_list.size() % num_elements != 0) {
751 return shape_error(out.float_list.size(), "float");
752 }
753 }
754 out.example_end_indices.push_back(out.float_list.size());
755 break;
756 }
757 case DT_STRING: {
758 if (example_dtype != DT_INVALID) {
759 if (!feature.ParseBytesList(&out.bytes_list)) {
760 return parse_error();
761 }
762 if (out.bytes_list.size() % num_elements != 0) {
763 return shape_error(out.bytes_list.size(), "bytes");
764 }
765 }
766 out.example_end_indices.push_back(out.bytes_list.size());
767 break;
768 }
769 default:
770 LOG(FATAL) << "Should not happen.";
771 }
772
773 if (output_stats) {
774 // Use `out.example_end_indices` to determine the feature-value count
775 // for this feature, because the preceding switch statement pushes
776 // the length of the appropriate feature list to that vector.
777 // TODO(b/111553342): If desirable, we could add support for counting
778 // elements in the features that aren't parsed, but this could add
779 // considerable runtime cost.
780 const size_t out_examples_count = out.example_end_indices.size();
781 if (out_examples_count == 1) {
782 output_stats->feature_values_count += out.example_end_indices[0];
783 } else {
784 output_stats->feature_values_count +=
785 out.example_end_indices[out_examples_count - 1] -
786 out.example_end_indices[out_examples_count - 2];
787 }
788 }
789 }
790 } else {
791 // Feature is sparse or ragged.
792 auto& last_example =
793 is_ragged ? ragged_feature_last_example : sparse_feature_last_example;
794
795 // If feature was already visited, skip.
796 // Compare comment at the beginning of the loop.
797 if (last_example[d] == example_index) {
798 LogSparseFeatureDataLoss(feature_name);
799 continue;
800 }
801 last_example[d] = example_index;
802
803 // Handle sparse features.
804 SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
805 DataType feature_dtype =
806 is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
807 if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
808 return example_error(
809 strings::StrCat("Data types don't match. ",
810 "Expected type: ", DataTypeString(feature_dtype),
811 ", Actual type: ", DataTypeString(example_dtype)));
812 }
813
814 switch (feature_dtype) {
815 case DT_INT64: {
816 if (example_dtype != DT_INVALID) {
817 if (!feature.ParseInt64List(&out.int64_list)) {
818 return parse_error();
819 }
820 }
821 out.example_end_indices.push_back(out.int64_list.size());
822 break;
823 }
824 case DT_FLOAT: {
825 if (example_dtype != DT_INVALID) {
826 if (!feature.ParseFloatList(&out.float_list)) {
827 return parse_error();
828 }
829 }
830 out.example_end_indices.push_back(out.float_list.size());
831 break;
832 }
833 case DT_STRING: {
834 if (example_dtype != DT_INVALID) {
835 if (!feature.ParseBytesList(&out.bytes_list)) {
836 return parse_error();
837 }
838 }
839 out.example_end_indices.push_back(out.bytes_list.size());
840 break;
841 }
842 default:
843 LOG(FATAL) << "Should not happen.";
844 }
845
846 if (output_stats) {
847 // Use `out.example_end_indices` to determine the feature-value count
848 // for this feature, because the preceding switch statement pushes
849 // the length of the appropriate feature list to that vector.
850 // TODO(b/111553342): If desirable, we could add support for counting
851 // elements in the features that aren't parsed, but this could add
852 // considerable runtime cost.
853 const size_t out_examples_count = out.example_end_indices.size();
854 if (out_examples_count == 1) {
855 output_stats->feature_values_count += out.example_end_indices[0];
856 } else {
857 output_stats->feature_values_count +=
858 out.example_end_indices[out_examples_count - 1] -
859 out.example_end_indices[out_examples_count - 2];
860 }
861 }
862 }
863 }
864
865 // Handle missing dense features for fixed strides.
866 for (size_t d = 0; d < config.dense.size(); ++d) {
867 if (config.dense[d].variable_length) continue;
868 if (dense_feature_last_example[d] == example_index) continue;
869 if (config.dense[d].default_value.NumElements() == 0) {
870 return errors::InvalidArgument(
871 "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
872 " (data type: ", DataTypeString(config.dense[d].dtype), ")",
873 " is required but could not be found.");
874 }
875 const Tensor& in = config.dense[d].default_value;
876 Tensor& out = (*output_dense)[d];
877 const std::size_t num_elements = in.shape().num_elements();
878 const std::size_t offset = example_index * num_elements;
879
880 switch (config.dense[d].dtype) {
881 case DT_INT64: {
882 std::copy_n(in.flat<int64>().data(), num_elements,
883 out.flat<int64>().data() + offset);
884 break;
885 }
886 case DT_FLOAT: {
887 std::copy_n(in.flat<float>().data(), num_elements,
888 out.flat<float>().data() + offset);
889 break;
890 }
891 case DT_STRING: {
892 std::copy_n(in.flat<tstring>().data(), num_elements,
893 out.flat<tstring>().data() + offset);
894 break;
895 }
896 default:
897 LOG(FATAL) << "Should not happen.";
898 }
899 }
900
901 // Handle missing varlen dense features.
902 for (size_t d = 0; d < config.dense.size(); ++d) {
903 if (!config.dense[d].variable_length) continue;
904 if (dense_feature_last_example[d] == example_index) continue;
905 SparseBuffer& out = (*output_varlen_dense)[d];
906 size_t prev_example_end_index =
907 out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
908 out.example_end_indices.push_back(prev_example_end_index);
909 }
910
911 // Handle missing sparse features.
912 for (size_t d = 0; d < config.sparse.size(); ++d) {
913 if (sparse_feature_last_example[d] == example_index) continue;
914 SparseBuffer& out = (*output_sparse)[d];
915 size_t prev_example_end_index =
916 out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
917 out.example_end_indices.push_back(prev_example_end_index);
918 }
919
920 // Handle missing ragged features.
921 for (size_t d = 0; d < config.ragged.size(); ++d) {
922 if (ragged_feature_last_example[d] == example_index) continue;
923 SparseBuffer& out = (*output_ragged)[d];
924 size_t prev_example_end_index =
925 out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
926 out.example_end_indices.push_back(prev_example_end_index);
927 }
928
929 return Status::OK();
930 }
931
CheckConfigDataType(DataType dtype)932 Status CheckConfigDataType(DataType dtype) {
933 switch (dtype) {
934 case DT_INT64:
935 case DT_FLOAT:
936 case DT_STRING:
937 return Status::OK();
938 default:
939 return errors::InvalidArgument("Invalid config dtype: ",
940 DataTypeString(dtype));
941 }
942 }
943
944 // Use this in the "default" clause of switch statements when dispatching
945 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)946 inline void ReportUnexpectedDataType(DataType dtype) {
947 DCHECK(false)
948 << "Encountered unexpected DataType " << DataTypeString(dtype)
949 << "in variable that should have been checked by CheckConfigDataType().";
950 }
951
CheckConfigDataTypes(const Config & config)952 Status CheckConfigDataTypes(const Config& config) {
953 // Check config so we can safely CHECK(false) in switches on config.*.dtype
954 for (auto& c : config.sparse) {
955 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
956 }
957 for (auto& c : config.dense) {
958 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
959 }
960 for (auto& c : config.ragged) {
961 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
962 if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
963 return errors::InvalidArgument("Invalid ragged_split_type: ",
964 DataTypeString(c.splits_dtype));
965 }
966 }
967 return Status::OK();
968 }
969
970 template <typename T>
971 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
972
973 template <>
GetListFromBuffer(const SparseBuffer & buffer)974 const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
975 return buffer.int64_list;
976 }
977 template <>
GetListFromBuffer(const SparseBuffer & buffer)978 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
979 return buffer.float_list;
980 }
981 template <>
GetListFromBuffer(const SparseBuffer & buffer)982 const SmallVector<tstring>& GetListFromBuffer<tstring>(
983 const SparseBuffer& buffer) {
984 return buffer.bytes_list;
985 }
986
// Copies the elements of the range [b, e) into the storage starting at `t`.
// This is the generic version; it always copies.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy_n(b, e - b, t);
}
991 template <>
CopyOrMoveBlock(const tstring * b,const tstring * e,tstring * t)992 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
993 std::move(b, e, t);
994 }
995
996 template <typename T>
FillAndCopyVarLen(const int d,const size_t num_elements,const size_t num_elements_per_minibatch,const Config & config,const std::vector<std::vector<SparseBuffer>> & varlen_dense_buffers,Tensor * values)997 void FillAndCopyVarLen(
998 const int d, const size_t num_elements,
999 const size_t num_elements_per_minibatch, const Config& config,
1000 const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
1001 Tensor* values) {
1002 const Tensor& default_value = config.dense[d].default_value;
1003
1004 // Copy-fill the tensors (creating the zero/fill-padding)
1005 std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
1006 default_value.flat<T>()(0));
1007
1008 // Data is [batch_size, max_num_elements, data_stride_size]
1009 // and num_elements_per_minibatch = max_num_elements * data_stride_size
1010 auto data = values->flat<T>().data();
1011
1012 // Iterate over minibatch elements
1013 for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
1014 const SparseBuffer& buffer = varlen_dense_buffers[i][d];
1015 // Number of examples being stored in this buffer
1016 const auto& end_indices = buffer.example_end_indices;
1017 const size_t examples_in_buffer = end_indices.size();
1018 // const size_t stride_size = config.dense[d].elements_per_stride;
1019
1020 const auto& list = GetListFromBuffer<T>(buffer);
1021 auto list_ptr = list.begin();
1022
1023 size_t elements_tally = 0;
1024 // Iterate through all the examples stored in this buffer.
1025 for (size_t j = 0; j < examples_in_buffer; ++j) {
1026 // Number of elements stored for this example.
1027 const size_t num_elems = end_indices[j] - elements_tally;
1028 CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
1029 // Move forward this many elements in the varlen buffer.
1030 list_ptr += num_elems;
1031 // Move forward to the next minibatch entry in the values output.
1032 data += num_elements_per_minibatch;
1033 elements_tally = end_indices[j];
1034 }
1035 DCHECK(elements_tally == list.size());
1036 }
1037 }
1038
1039 // Thin vector like interface wrapper around a Tensor. This enable us to
1040 // directly populate a tensor during parsing instead of having to first create a
1041 // vactor and then copy the data over.
1042 template <typename T>
1043 class TensorVector {
1044 public:
1045 using value_type = T;
1046
tensor()1047 const Tensor& tensor() {
1048 if (!tensor_.has_value()) {
1049 resize(0);
1050 }
1051 return *tensor_;
1052 }
1053
size() const1054 int64 size() const {
1055 return tensor_.has_value() ? tensor_->NumElements() : 0;
1056 }
resize(int64 new_size)1057 void resize(int64 new_size) {
1058 DCHECK(!tensor_.has_value());
1059 tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
1060 data_ = tensor_->flat<T>().data();
1061 }
data()1062 T* data() { return data_; }
data() const1063 const T* data() const { return data_; }
1064
1065 private:
1066 // Use absl::optional to avoid calling the default constructor of Tensor
1067 // unnecessarily.
1068 absl::optional<Tensor> tensor_;
1069
1070 // Cached pointer to the raw data inside the tensor.
1071 T* data_ = nullptr;
1072 };
1073
CountSparseFeatures(const std::vector<std::vector<SparseBuffer>> & sparse_buffers,size_t d,size_t * total_num_features,size_t * max_num_features)1074 void CountSparseFeatures(
1075 const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
1076 size_t* total_num_features, size_t* max_num_features) {
1077 for (auto& sparse_values_tmp : sparse_buffers) {
1078 const std::vector<size_t>& end_indices =
1079 sparse_values_tmp[d].example_end_indices;
1080 *total_num_features += end_indices.back();
1081 *max_num_features = std::max(*max_num_features, end_indices[0]);
1082 for (size_t i = 1; i < end_indices.size(); ++i) {
1083 size_t example_size = end_indices[i] - end_indices[i - 1];
1084 *max_num_features = std::max(*max_num_features, example_size);
1085 }
1086 }
1087 }
1088
CopySparseBufferToTensor(DataType dtype,size_t offset,SparseBuffer * src,Tensor * dst)1089 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
1090 Tensor* dst) {
1091 switch (dtype) {
1092 case DT_INT64: {
1093 std::copy(src->int64_list.begin(), src->int64_list.end(),
1094 dst->flat<int64>().data() + offset);
1095 break;
1096 }
1097 case DT_FLOAT: {
1098 std::copy(src->float_list.begin(), src->float_list.end(),
1099 dst->flat<float>().data() + offset);
1100 break;
1101 }
1102 case DT_STRING: {
1103 std::move(src->bytes_list.begin(), src->bytes_list.end(),
1104 dst->flat<tstring>().data() + offset);
1105 break;
1106 }
1107 default:
1108 ReportUnexpectedDataType(dtype);
1109 }
1110 }
1111
1112 } // namespace
1113
FastParseExample(const Config & config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * result)1114 Status FastParseExample(const Config& config,
1115 gtl::ArraySlice<tstring> serialized,
1116 gtl::ArraySlice<tstring> example_names,
1117 thread::ThreadPool* thread_pool, Result* result) {
1118 DCHECK(result != nullptr);
1119 // Check config so we can safely CHECK(false) in switches on config.*.dtype
1120 TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1121
1122 if (config.collect_feature_stats) {
1123 result->feature_stats.resize(serialized.size());
1124 }
1125
1126 size_t config_size =
1127 config.dense.size() + config.sparse.size() + config.ragged.size();
1128 SeededHasher hasher;
1129 // Build config index.
1130 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1131 bool ok = true;
1132 for (size_t i = 0; i < 1000; ++i) {
1133 for (size_t d = 0; d < config.dense.size(); ++d) {
1134 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1135 {d, Type::Dense});
1136 }
1137 for (size_t d = 0; d < config.sparse.size(); ++d) {
1138 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1139 {d, Type::Sparse});
1140 }
1141 for (size_t d = 0; d < config.ragged.size(); ++d) {
1142 ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1143 {d, Type::Ragged});
1144 }
1145 if (ok) break;
1146 LOG(WARNING) << "Collision found. This should happen only if you have "
1147 "around 2^32 entries in your config.";
1148 hasher.seed++;
1149 config_index.Clear(config_size);
1150 ok = true;
1151 }
1152 if (!ok) {
1153 return errors::Internal(
1154 "Could not avoid collision. This should not happen.");
1155 }
1156
1157 // Allocate dense output for fixed length dense values
1158 // (variable-length dense and sparse and ragged have to be buffered).
1159 std::vector<Tensor> fixed_dense_values(config.dense.size());
1160 for (size_t d = 0; d < config.dense.size(); ++d) {
1161 if (config.dense[d].variable_length) continue;
1162 TensorShape out_shape;
1163 out_shape.AddDim(serialized.size());
1164 for (const int64 dim : config.dense[d].shape.dim_sizes()) {
1165 out_shape.AddDim(dim);
1166 }
1167 fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
1168 }
1169
1170 // This parameter affects performance in a big and data-dependent way.
1171 const size_t kMiniBatchSizeBytes = 50000;
1172
1173 // Calculate number of minibatches.
1174 // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
1175 // Apply 'special logic' below for small and big regimes.
1176 const size_t num_minibatches = [&] {
1177 size_t result = 0;
1178 size_t minibatch_bytes = 0;
1179 for (size_t i = 0; i < serialized.size(); i++) {
1180 if (minibatch_bytes == 0) { // start minibatch
1181 result++;
1182 }
1183 minibatch_bytes += serialized[i].size() + 1;
1184 if (minibatch_bytes > kMiniBatchSizeBytes) {
1185 minibatch_bytes = 0;
1186 }
1187 }
1188 // 'special logic'
1189 const size_t min_minibatches = std::min<size_t>(8, serialized.size());
1190 const size_t max_minibatches = 64;
1191 return std::max<size_t>(min_minibatches,
1192 std::min<size_t>(max_minibatches, result));
1193 }();
1194
1195 auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
1196 return (serialized.size() * minibatch) / num_minibatches;
1197 };
1198
1199 // TODO(lew): A big performance low-hanging fruit here is to improve
1200 // num_minibatches calculation to take into account actual amount of work
1201 // needed, as the size in bytes is not perfect. Linear combination of
1202 // size in bytes and average number of features per example is promising.
1203 // Even better: measure time instead of estimating, but this is too costly
1204 // in small batches.
1205 // Maybe accept outside parameter #num_minibatches?
1206
1207 // Do minibatches in parallel.
1208 std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
1209 std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
1210 std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
1211 std::vector<Status> status_of_minibatch(num_minibatches);
1212 auto ProcessMiniBatch = [&](size_t minibatch) {
1213 sparse_buffers[minibatch].resize(config.sparse.size());
1214 varlen_dense_buffers[minibatch].resize(config.dense.size());
1215 ragged_buffers[minibatch].resize(config.ragged.size());
1216 size_t start = first_example_of_minibatch(minibatch);
1217 size_t end = first_example_of_minibatch(minibatch + 1);
1218 for (size_t e = start; e < end; ++e) {
1219 PerExampleFeatureStats* stats = nullptr;
1220 if (config.collect_feature_stats) {
1221 stats = &result->feature_stats[e];
1222 }
1223 status_of_minibatch[minibatch] = FastParseSerializedExample(
1224 serialized[e],
1225 (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
1226 config_index, hasher, &fixed_dense_values,
1227 &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
1228 &ragged_buffers[minibatch], stats);
1229 if (!status_of_minibatch[minibatch].ok()) break;
1230 }
1231 };
1232
1233 ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
1234
1235 for (Status& status : status_of_minibatch) {
1236 TF_RETURN_IF_ERROR(status);
1237 }
1238
1239 result->sparse_indices.reserve(config.sparse.size());
1240 result->sparse_values.reserve(config.sparse.size());
1241 result->sparse_shapes.reserve(config.sparse.size());
1242 result->dense_values.reserve(config.dense.size());
1243 result->ragged_values.reserve(config.ragged.size());
1244 result->ragged_splits.reserve(config.ragged.size());
1245
1246 for (size_t d = 0; d < config.dense.size(); ++d) {
1247 result->dense_values.push_back(std::move(fixed_dense_values[d]));
1248 }
1249
1250 // Merge SparseBuffers from all minibatches for every config.sparse.
1251 auto MergeSparseMinibatches = [&](size_t d) {
1252 // Loop over minibatches
1253 size_t total_num_features = 0;
1254 size_t max_num_features = 0;
1255 CountSparseFeatures(sparse_buffers, d, &total_num_features,
1256 &max_num_features);
1257
1258 TensorShape indices_shape;
1259 indices_shape.AddDim(total_num_features);
1260 indices_shape.AddDim(2);
1261 result->sparse_indices.emplace_back(DT_INT64, indices_shape);
1262 Tensor* indices = &result->sparse_indices.back();
1263
1264 TensorShape values_shape;
1265 values_shape.AddDim(total_num_features);
1266 result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
1267 Tensor* values = &result->sparse_values.back();
1268
1269 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
1270 auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
1271 shapes_shape_t(0) = serialized.size();
1272 shapes_shape_t(1) = max_num_features;
1273
1274 size_t offset = 0;
1275 for (size_t i = 0; i < sparse_buffers.size(); ++i) {
1276 SparseBuffer& buffer = sparse_buffers[i][d];
1277
1278 // Update indices.
1279 size_t delta = 0;
1280
1281 if (indices->NumElements() > 0) {
1282 int64* ix_p = &indices->matrix<int64>()(offset, 0);
1283 size_t example_index = first_example_of_minibatch(i);
1284 for (size_t example_end_index : buffer.example_end_indices) {
1285 size_t feature_index = 0;
1286 for (; delta < example_end_index; ++delta) {
1287 // Column 0: example index
1288 *ix_p = example_index;
1289 // Column 1: the feature index buffer example
1290 *(ix_p + 1) = feature_index;
1291 ix_p += 2;
1292 ++feature_index;
1293 }
1294 ++example_index;
1295 }
1296 }
1297
1298 CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
1299 offset += delta;
1300 }
1301 };
1302
1303 // Merge SparseBuffers from all minibatches for every config.ragged.
1304 auto MergeRaggedMinibatches = [&](size_t d) {
1305 // Loop over minibatches
1306 size_t total_num_features = 0;
1307 size_t max_num_features = 0;
1308 CountSparseFeatures(ragged_buffers, d, &total_num_features,
1309 &max_num_features);
1310
1311 TensorShape row_splits_shape;
1312 row_splits_shape.AddDim(serialized.size() + 1);
1313 result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
1314 row_splits_shape);
1315 Tensor* row_splits = &result->ragged_splits.back();
1316 if (config.ragged[d].splits_dtype == DT_INT64) {
1317 row_splits->flat<int64>()(0) = 0;
1318 } else {
1319 row_splits->flat<int32>()(0) = 0;
1320 }
1321
1322 TensorShape values_shape;
1323 values_shape.AddDim(total_num_features);
1324 result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
1325 Tensor* values = &result->ragged_values.back();
1326
1327 size_t values_offset = 0;
1328 size_t splits_offset = 0;
1329 for (size_t i = 0; i < ragged_buffers.size(); ++i) {
1330 SparseBuffer& buffer = ragged_buffers[i][d];
1331 if (buffer.example_end_indices.empty()) continue;
1332
1333 // Update row_splits. row_splits are formed by concatenating the example
1334 // end_indices (adjusting each to start after the previous one ends).
1335 if (config.ragged[d].splits_dtype == DT_INT64) {
1336 int64* row_splits_out = &row_splits->flat<int64>()(splits_offset);
1337 int64 start = *row_splits_out;
1338 for (size_t example_end_index : buffer.example_end_indices) {
1339 *++row_splits_out = start + example_end_index;
1340 }
1341 } else {
1342 int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
1343 int32 start = *row_splits_out;
1344 for (size_t example_end_index : buffer.example_end_indices) {
1345 *++row_splits_out = start + example_end_index;
1346 }
1347 }
1348
1349 CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
1350 values);
1351 values_offset += buffer.example_end_indices.back();
1352 splits_offset += buffer.example_end_indices.size();
1353 }
1354 };
1355
1356 // Merge SparseBuffers from all minibatches for every config.dense having
1357 // variable_length.
1358 auto MergeDenseVarLenMinibatches = [&](size_t d) {
1359 if (!config.dense[d].variable_length) return;
1360
1361 // Loop over minibatches
1362 size_t max_num_features = 0;
1363 for (auto& dense_values_tmp : varlen_dense_buffers) {
1364 std::vector<size_t>& end_indices =
1365 dense_values_tmp[d].example_end_indices;
1366 max_num_features = std::max(max_num_features, end_indices[0]);
1367 for (size_t i = 1; i < end_indices.size(); ++i) {
1368 size_t example_size = end_indices[i] - end_indices[i - 1];
1369 max_num_features = std::max(max_num_features, example_size);
1370 }
1371 }
1372
1373 const size_t stride_size = config.dense[d].elements_per_stride;
1374 const size_t max_num_elements = max_num_features / stride_size;
1375 TensorShape values_shape;
1376 DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
1377 const size_t batch_size = serialized.size();
1378 values_shape.AddDim(batch_size);
1379 values_shape.AddDim(max_num_elements);
1380 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1381 values_shape.AddDim(config.dense[d].shape.dim_size(i));
1382 }
1383 Tensor values(config.dense[d].dtype, values_shape);
1384 result->dense_values[d] = values;
1385 const size_t num_elements = values.NumElements();
1386
1387 // Nothing to write, exit early.
1388 if (num_elements == 0) return;
1389
1390 const size_t num_elements_per_minibatch = num_elements / batch_size;
1391
1392 switch (config.dense[d].dtype) {
1393 case DT_INT64: {
1394 FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
1395 config, varlen_dense_buffers, &values);
1396 break;
1397 }
1398 case DT_FLOAT: {
1399 FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
1400 config, varlen_dense_buffers, &values);
1401 break;
1402 }
1403 case DT_STRING: {
1404 FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
1405 config, varlen_dense_buffers, &values);
1406 break;
1407 }
1408 default:
1409 ReportUnexpectedDataType(config.dense[d].dtype);
1410 }
1411 };
1412
1413 for (size_t d = 0; d < config.dense.size(); ++d) {
1414 MergeDenseVarLenMinibatches(d);
1415 }
1416
1417 for (size_t d = 0; d < config.sparse.size(); ++d) {
1418 MergeSparseMinibatches(d);
1419 }
1420
1421 for (size_t d = 0; d < config.ragged.size(); ++d) {
1422 MergeRaggedMinibatches(d);
1423 }
1424
1425 return Status::OK();
1426 }
1427
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1428 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1429 Result* result) {
1430 DCHECK(result != nullptr);
1431 // Check config so we can safely CHECK(false) in switches on config.*.dtype
1432 TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1433
1434 PerExampleFeatureStats* stats = nullptr;
1435 if (config.collect_feature_stats) {
1436 result->feature_stats.emplace_back();
1437 stats = &result->feature_stats.back();
1438 }
1439
1440 // TODO(mrry): Cache the construction of this map at Op construction time.
1441 size_t config_size =
1442 config.dense.size() + config.sparse.size() + config.ragged.size();
1443 SeededHasher hasher;
1444 // Build config index.
1445 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1446 bool ok = true;
1447 for (size_t i = 0; i < 1000; ++i) {
1448 for (size_t d = 0; d < config.dense.size(); ++d) {
1449 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1450 {d, Type::Dense});
1451 }
1452 for (size_t d = 0; d < config.sparse.size(); ++d) {
1453 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1454 {d, Type::Sparse});
1455 }
1456 for (size_t d = 0; d < config.ragged.size(); ++d) {
1457 ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1458 {d, Type::Ragged});
1459 }
1460 if (ok) break;
1461 LOG(WARNING) << "Collision found. This should happen only if you have "
1462 "around 2^32 entries in your config.";
1463 hasher.seed++;
1464 config_index.Clear(config_size);
1465 ok = true;
1466 }
1467 if (!ok) {
1468 return errors::Internal(
1469 "Could not avoid collision. This should not happen.");
1470 }
1471
1472 result->sparse_indices.reserve(config.sparse.size());
1473 result->sparse_values.reserve(config.sparse.size());
1474 result->sparse_shapes.reserve(config.sparse.size());
1475 result->dense_values.reserve(config.dense.size());
1476 result->ragged_values.reserve(config.ragged.size());
1477 result->ragged_splits.reserve(config.ragged.size());
1478
1479 // Allocate dense output tensors.
1480 for (size_t d = 0; d < config.dense.size(); ++d) {
1481 if (!config.dense[d].variable_length) {
1482 TensorShape values_shape;
1483 if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1484 return errors::Internal(
1485 "Fixed-length shape was not a statically defined shape.");
1486 }
1487 result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1488 } else {
1489 // Variable-length tensor will be allocated later.
1490 result->dense_values.emplace_back();
1491 }
1492 }
1493
1494 // Allocate sparse output tensors.
1495 for (size_t d = 0; d < config.sparse.size(); ++d) {
1496 // The dense_shape is always a vector of length 1.
1497 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1498 // Variable-length tensors will be allocated later.
1499 result->sparse_indices.emplace_back();
1500 result->sparse_values.emplace_back();
1501 }
1502
1503 // Allocate ragged output tensors.
1504 for (size_t d = 0; d < config.ragged.size(); ++d) {
1505 // Variable-length values tensors will be allocated later.
1506 result->ragged_values.emplace_back();
1507 // Splits tensors are empty (unused) for single (scalar) inputs.
1508 const auto splits_dtype = config.ragged[d].splits_dtype;
1509 result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1510 }
1511
1512 parsed::Example parsed_example;
1513 if (!ParseExample(serialized, &parsed_example)) {
1514 return errors::InvalidArgument("Could not parse example input, value: '",
1515 serialized, "'");
1516 }
1517 std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1518 std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1519 std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1520
1521 if (stats) {
1522 // TODO(b/111553342): This may over-count the number of features if there
1523 // are duplicate keys in the feature map. Consider deduplicating the keys
1524 // before computing the count.
1525 stats->features_count = parsed_example.size();
1526 }
1527
1528 // Handle features present in the example.
1529 const size_t parsed_example_size = parsed_example.size();
1530 for (size_t i = 0; i < parsed_example_size; ++i) {
1531 // This is a logic that standard protobuf parsing is implementing.
1532 // I.e. last entry in the map overwrites all the previous ones.
1533 parsed::FeatureMapEntry& name_and_feature =
1534 parsed_example[parsed_example_size - i - 1];
1535
1536 const StringPiece feature_name = name_and_feature.first;
1537 parsed::Feature& feature = name_and_feature.second;
1538
1539 std::pair<size_t, Type> d_and_type;
1540 uint64 h = hasher(feature_name);
1541 if (!config_index.Find(h, &d_and_type)) continue;
1542
1543 size_t d = d_and_type.first;
1544 bool is_dense = d_and_type.second == Type::Dense;
1545 bool is_sparse = d_and_type.second == Type::Sparse;
1546
1547 {
1548 // Testing for PresizedCuckooMap collision.
1549 // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1550 const tstring& config_feature_name =
1551 is_dense ? config.dense[d].feature_name
1552 : (is_sparse ? config.sparse[d].feature_name
1553 : config.ragged[d].feature_name);
1554 if (feature_name != config_feature_name) continue;
1555 }
1556
1557 auto example_error = [feature_name](StringPiece suffix) {
1558 return errors::InvalidArgument("Key: ", feature_name, ". ", suffix);
1559 };
1560
1561 auto parse_error = [feature_name] {
1562 return errors::InvalidArgument("Key: ", feature_name,
1563 ". Can't parse serialized Example.");
1564 };
1565
1566 DataType example_dtype;
1567 TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1568 if (example_dtype == DT_INVALID) continue;
1569
1570 if (is_dense && !config.dense[d].variable_length) {
1571 // If feature was already visited, skip.
1572 // Compare comment at the beginning of the loop.
1573 if (dense_feature_already_seen[d]) {
1574 LogDenseFeatureDataLoss(feature_name);
1575 continue;
1576 }
1577 dense_feature_already_seen[d] = true;
1578
1579 if (example_dtype != config.dense[d].dtype) {
1580 return example_error(strings::StrCat(
1581 "Data types don't match. Data type: ",
1582 DataTypeString(example_dtype),
1583 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1584 }
1585
1586 Tensor* out = &result->dense_values[d];
1587 const std::size_t num_elements = config.dense[d].elements_per_stride;
1588 if (stats) {
1589 // TODO(b/111553342): If desirable, we could add support for counting
1590 // elements in the features that aren't parsed, but this could add
1591 // considerable runtime cost.
1592 stats->feature_values_count += num_elements;
1593 }
1594 switch (example_dtype) {
1595 case DT_INT64: {
1596 auto out_p = out->flat<int64>().data();
1597 LimitedArraySlice<int64> slice(out_p, num_elements);
1598 if (!feature.ParseInt64List(&slice)) return parse_error();
1599 if (slice.EndDistance() != 0) {
1600 return parse_error();
1601 }
1602 break;
1603 }
1604 case DT_FLOAT: {
1605 auto out_p = out->flat<float>().data();
1606 LimitedArraySlice<float> slice(out_p, num_elements);
1607 if (!feature.ParseFloatList(&slice)) return parse_error();
1608 if (slice.EndDistance() != 0) {
1609 return parse_error();
1610 }
1611 break;
1612 }
1613 case DT_STRING: {
1614 auto out_p = out->flat<tstring>().data();
1615 LimitedArraySlice<tstring> slice(out_p, num_elements);
1616 if (!feature.ParseBytesList(&slice)) return parse_error();
1617 if (slice.EndDistance() != 0) {
1618 return parse_error();
1619 }
1620 break;
1621 }
1622 default:
1623 ReportUnexpectedDataType(example_dtype);
1624 }
1625
1626 } else { // if variable length
1627 SmallVector<tstring> bytes_list;
1628 TensorVector<float> float_list;
1629 SmallVector<int64> int64_list;
1630
1631 const size_t num_elements_divisor =
1632 is_dense ? config.dense[d].elements_per_stride : 1;
1633 size_t num_elements;
1634
1635 if (is_dense) {
1636 // If feature was already visited, skip.
1637 // Compare comment at the beginning of the loop.
1638 if (dense_feature_already_seen[d]) {
1639 LogDenseFeatureDataLoss(feature_name);
1640 continue;
1641 }
1642 dense_feature_already_seen[d] = true;
1643 if (example_dtype != config.dense[d].dtype) {
1644 return example_error(strings::StrCat(
1645 "Data types don't match. Data type: ",
1646 DataTypeString(example_dtype),
1647 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1648 }
1649 } else {
1650 // Feature is sparse or ragged.
1651 auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
1652 : ragged_feature_already_seen;
1653 auto& feature_dtype =
1654 is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
1655 // If feature was already visited, skip.
1656 // Compare comment at the beginning of the loop.
1657 if (feature_already_seen[d]) {
1658 LogSparseFeatureDataLoss(feature_name);
1659 continue;
1660 }
1661 feature_already_seen[d] = true;
1662
1663 // Handle sparse features.
1664 if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
1665 return example_error(strings::StrCat(
1666 "Data types don't match. ",
1667 "Expected type: ", DataTypeString(feature_dtype),
1668 ", Actual type: ", DataTypeString(example_dtype)));
1669 }
1670 }
1671
1672 switch (example_dtype) {
1673 case DT_INT64: {
1674 // TODO(mrry): Use the fact that the `int64_list` is packed to read
1675 // out the length and pre-allocate the output tensor.
1676 if (!feature.ParseInt64List(&int64_list)) return parse_error();
1677 num_elements = int64_list.size();
1678 break;
1679 }
1680 case DT_FLOAT: {
1681 if (!feature.ParseFloatList(&float_list)) return parse_error();
1682 num_elements = float_list.size();
1683 break;
1684 }
1685 case DT_STRING: {
1686 int actual_num_elements = 0;
1687 if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1688 return parse_error();
1689 }
1690 bytes_list.reserve(actual_num_elements);
1691 if (!feature.ParseBytesList(&bytes_list)) return parse_error();
1692 num_elements = bytes_list.size();
1693 break;
1694 }
1695 default:
1696 num_elements = 0;
1697 ReportUnexpectedDataType(example_dtype);
1698 }
1699
1700 if (num_elements % num_elements_divisor != 0) {
1701 return parse_error();
1702 }
1703
1704 if (stats) {
1705 stats->feature_values_count += num_elements;
1706 }
1707
1708 Tensor* out;
1709 DataType out_dtype;
1710 TensorShape out_shape;
1711 if (is_dense) {
1712 out_shape.AddDim(num_elements / num_elements_divisor);
1713 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1714 out_shape.AddDim(config.dense[d].shape.dim_size(i));
1715 }
1716
1717 out = &result->dense_values[d];
1718 out_dtype = config.dense[d].dtype;
1719 } else if (is_sparse) {
1720 Tensor* out_indices = &result->sparse_indices[d];
1721 Tensor* out_dense_shape = &result->sparse_shapes[d];
1722
1723 // TODO(mrry): Investigate the possibility of not materializing
1724 // the indices (and perhaps dense_shape) until they are needed.
1725 *out_indices = Tensor(
1726 DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
1727 auto indices_flat = out_indices->flat<int64>();
1728 for (size_t i = 0; i < num_elements; ++i) {
1729 indices_flat(i) = static_cast<int64>(i);
1730 }
1731
1732 *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1733 auto shapes_shape_t = out_dense_shape->vec<int64>();
1734 shapes_shape_t(0) = num_elements;
1735
1736 out = &result->sparse_values[d];
1737 out_dtype = config.sparse[d].dtype;
1738 out_shape.AddDim(num_elements);
1739 } else {
1740 out = &result->ragged_values[d];
1741 out_dtype = config.ragged[d].dtype;
1742 out_shape.AddDim(num_elements);
1743 }
1744
1745 switch (example_dtype) {
1746 case DT_INT64: {
1747 *out = Tensor(out_dtype, out_shape);
1748 CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1749 out->flat<int64>().data());
1750 break;
1751 }
1752 case DT_FLOAT: {
1753 if (!out->CopyFrom(float_list.tensor(), out_shape)) {
1754 return parse_error();
1755 }
1756 break;
1757 }
1758 case DT_STRING: {
1759 *out = Tensor(out_dtype, out_shape);
1760 CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
1761 out->flat<tstring>().data());
1762 break;
1763 }
1764 default:
1765 ReportUnexpectedDataType(example_dtype);
1766 }
1767 }
1768 }
1769
1770 // Handle missing dense features.
1771 for (size_t d = 0; d < config.dense.size(); ++d) {
1772 if (!dense_feature_already_seen[d]) {
1773 if (!config.dense[d].variable_length) {
1774 // Handle missing fixed-length dense feature.
1775 if (config.dense[d].default_value.NumElements() == 0) {
1776 return errors::InvalidArgument(
1777 "Feature: ", config.dense[d].feature_name,
1778 " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1779 " is required but could not be found.");
1780 }
1781 result->dense_values[d] = config.dense[d].default_value;
1782 } else {
1783 // Handle missing varlen dense feature.
1784 TensorShape empty_shape;
1785 empty_shape.AddDim(0);
1786 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1787 empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1788 }
1789 result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1790 }
1791 }
1792 }
1793
1794 // Handle missing sparse features.
1795 for (size_t d = 0; d < config.sparse.size(); ++d) {
1796 if (!sparse_feature_already_seen[d]) {
1797 result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1798 result->sparse_values[d] =
1799 Tensor(config.sparse[d].dtype, TensorShape({0}));
1800 result->sparse_shapes[d].vec<int64>()(0) = 0;
1801 }
1802 }
1803
1804 // Handle missing ragged features.
1805 for (size_t d = 0; d < config.ragged.size(); ++d) {
1806 if (!ragged_feature_already_seen[d]) {
1807 result->ragged_values[d] =
1808 Tensor(config.ragged[d].dtype, TensorShape({0}));
1809 }
1810 }
1811
1812 return Status::OK();
1813 }
1814
1815 // Private helper functions for FastParseSequenceExample.
1816 namespace {
1817
1818 // A struct used by FastParseSequenceExample to hold the serialized proto
1819 // substrings for a single feature, plus some auxiliary information derived
1820 // from those protos (such as the total value length).
1821 struct FeatureProtos {
1822 // Proto substrings from each serialized SequenceExample that correspond
1823 // with this feature. `protos_present` records whether the proto had a
1824 // value defined (even if that value is empty).
1825 std::vector<StringPiece> protos;
1826 std::vector<bool> protos_present;
1827
1828 // Information derived from protos:
1829 size_t length; // total length for ragged/sparse, max row length for dense.
1830 size_t num_rows; // only populated for ragged sequence features.
1831
1832 // Information from the config:
1833 Type type; // Whether this feature is sparse, ragged, or dense.
1834 DataType dtype;
1835 };
1836
1837 // Map from feature name to FeatureProtos for that feature.
1838 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
1839
ExampleName(const gtl::ArraySlice<tstring> example_names,int n)1840 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
1841 return example_names.empty() ? "<unknown>" : example_names[n];
1842 }
1843
1844 // Return the number of bytes elements parsed, or -1 on error. If out is null,
1845 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)1846 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
1847 tstring* out) {
1848 int num_elements = 0;
1849 uint32 length;
1850 if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
1851 return -1;
1852 }
1853 if (length > 0) {
1854 auto limit = stream->PushLimit(length);
1855 while (!stream->ExpectAtEnd()) {
1856 uint32 bytes_length;
1857 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1858 !stream->ReadVarint32(&bytes_length)) {
1859 return -1;
1860 }
1861 if (out == nullptr) {
1862 stream->Skip(bytes_length);
1863 } else {
1864 out->resize_uninitialized(bytes_length);
1865 if (!stream->ReadRaw(out->data(), bytes_length)) {
1866 return -1;
1867 }
1868 out++;
1869 }
1870 num_elements++;
1871 }
1872 stream->PopLimit(limit);
1873 }
1874 return num_elements;
1875 }
1876
// Writes `num_to_pad` zeros starting at `out`.  A non-positive `num_to_pad`
// writes nothing.
inline void PadFloatFeature(int num_to_pad, float* out) {
  std::fill_n(out, num_to_pad, 0.0f);
}
1882
PadInt64Feature(int num_to_pad,int64 * out)1883 inline void PadInt64Feature(int num_to_pad, int64* out) {
1884 for (int i = 0; i < num_to_pad; i++) {
1885 *out++ = 0;
1886 }
1887 }
1888
1889 // Return the number of float elements parsed, or -1 on error. If out is null,
1890 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)1891 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
1892 float* out) {
1893 int num_elements = 0;
1894 uint32 length;
1895 if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
1896 return -1;
1897 }
1898 if (length > 0) {
1899 auto limit = stream->PushLimit(length);
1900 uint8 peek_tag = PeekTag(stream);
1901 if (peek_tag == kDelimitedTag(1)) { // packed
1902 uint32 packed_length;
1903 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1904 !stream->ReadVarint32(&packed_length)) {
1905 return -1;
1906 }
1907 auto packed_limit = stream->PushLimit(packed_length);
1908 while (!stream->ExpectAtEnd()) {
1909 uint32 buffer32;
1910 if (!stream->ReadLittleEndian32(&buffer32)) {
1911 return -1;
1912 }
1913 if (out != nullptr) {
1914 *out++ = absl::bit_cast<float>(buffer32);
1915 }
1916 num_elements++;
1917 }
1918 stream->PopLimit(packed_limit);
1919 } else if (peek_tag == kFixed32Tag(1)) {
1920 while (!stream->ExpectAtEnd()) {
1921 uint32 buffer32;
1922 if (!stream->ExpectTag(kFixed32Tag(1)) ||
1923 !stream->ReadLittleEndian32(&buffer32)) {
1924 return -1;
1925 }
1926 if (out != nullptr) {
1927 *out++ = absl::bit_cast<float>(buffer32);
1928 }
1929 num_elements++;
1930 }
1931 } else {
1932 // Unknown tag.
1933 return -1;
1934 }
1935 stream->PopLimit(limit);
1936 }
1937 return num_elements;
1938 }
1939
1940 // Return the number of int64 elements parsed, or -1 on error. If out is null,
1941 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64 * out)1942 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
1943 int64* out) {
1944 int num_elements = 0;
1945 uint32 length;
1946 if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
1947 return -1;
1948 }
1949 if (length > 0) {
1950 auto limit = stream->PushLimit(length);
1951 uint8 peek_tag = PeekTag(stream);
1952 if (peek_tag == kDelimitedTag(1)) { // packed
1953 uint32 packed_length;
1954 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1955 !stream->ReadVarint32(&packed_length)) {
1956 return -1;
1957 }
1958 auto packed_limit = stream->PushLimit(packed_length);
1959 while (!stream->ExpectAtEnd()) {
1960 protobuf_uint64 n; // There is no API for int64
1961 if (!stream->ReadVarint64(&n)) {
1962 return -1;
1963 }
1964 if (out != nullptr) {
1965 *out++ = n;
1966 }
1967 num_elements++;
1968 }
1969 stream->PopLimit(packed_limit);
1970 } else if (peek_tag == kVarintTag(1)) {
1971 while (!stream->ExpectAtEnd()) {
1972 protobuf_uint64 n; // There is no API for int64
1973 if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
1974 return -1;
1975 }
1976 if (out != nullptr) {
1977 *out++ = n;
1978 }
1979 num_elements++;
1980 }
1981 } else {
1982 // Unknown tag.
1983 return -1;
1984 }
1985 stream->PopLimit(limit);
1986 }
1987 return num_elements;
1988 }
1989
1990 // Parses the next feature on `stream` into `out` starting at `out_offset`.
1991 // Updates `out_offset`, and returns the number of values added.
1992 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)1993 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
1994 Tensor* out, size_t* out_offset) {
1995 int delta;
1996 switch (dtype) {
1997 case DT_STRING:
1998 delta =
1999 ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
2000 break;
2001 case DT_FLOAT:
2002 delta =
2003 ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
2004 break;
2005 case DT_INT64:
2006 delta =
2007 ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
2008 break;
2009 default:
2010 ReportUnexpectedDataType(dtype);
2011 delta = 0;
2012 }
2013 if (delta > 0) {
2014 *out_offset += delta;
2015 }
2016 return delta;
2017 }
2018
2019 // Returns the length of the next feature on `stream`.
2020 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)2021 inline int GetFeatureLength(DataType dtype,
2022 protobuf::io::CodedInputStream* stream) {
2023 switch (dtype) {
2024 case DT_STRING:
2025 return ParseBytesFeature(stream, nullptr);
2026 case DT_FLOAT:
2027 return ParseFloatFeature(stream, nullptr);
2028 case DT_INT64:
2029 return ParseInt64Feature(stream, nullptr);
2030 default:
2031 ReportUnexpectedDataType(dtype);
2032 return -1;
2033 }
2034 }
2035
ParseDataType(protobuf::io::CodedInputStream * stream)2036 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
2037 uint8 peek_tag = PeekTag(stream);
2038 switch (peek_tag) {
2039 case kDelimitedTag(1):
2040 return DT_STRING;
2041 case kDelimitedTag(2):
2042 return DT_FLOAT;
2043 case kDelimitedTag(3):
2044 return DT_INT64;
2045 default:
2046 return DT_INVALID;
2047 }
2048 }
2049
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)2050 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
2051 DataType dtype) {
2052 switch (dtype) {
2053 case DT_STRING:
2054 if (!stream->ExpectTag(kDelimitedTag(1))) {
2055 return false;
2056 }
2057 break;
2058 case DT_FLOAT:
2059 if (!stream->ExpectTag(kDelimitedTag(2))) {
2060 return false;
2061 }
2062 break;
2063 case DT_INT64:
2064 if (!stream->ExpectTag(kDelimitedTag(3))) {
2065 return false;
2066 }
2067 break;
2068 default:
2069 return false;
2070 }
2071 uint32 length;
2072 return stream->ReadVarint32(&length) && length == 0;
2073 }
2074
2075 // Reads an example proto, and extracts a StringPiece pointer to each feature.
ExtractFeaturesFromSequenceExamples(const gtl::ArraySlice<tstring> examples,const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features,FeatureProtosMap * sequence_features)2076 Status ExtractFeaturesFromSequenceExamples(
2077 const gtl::ArraySlice<tstring> examples,
2078 const gtl::ArraySlice<tstring> example_names,
2079 FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
2080 for (int d = 0; d < examples.size(); d++) {
2081 const tstring& example = examples[d];
2082 protobuf::io::CodedInputStream stream(
2083 reinterpret_cast<const uint8*>(example.data()), example.size());
2084 // Not clear what this does. Why not stream.EnableAliasing()?
2085 EnableAliasing(&stream);
2086
2087 // Extract pointers to all features within this serialized example.
2088 while (!stream.ExpectAtEnd()) {
2089 FeatureProtosMap* features = nullptr;
2090 if (stream.ExpectTag(kDelimitedTag(1))) {
2091 // Context
2092 features = context_features;
2093 } else if (stream.ExpectTag(kDelimitedTag(2))) {
2094 // Sequence
2095 features = sequence_features;
2096 } else if (!SkipExtraneousTag(&stream)) {
2097 return errors::InvalidArgument(
2098 "Invalid protocol message input, example id: ",
2099 ExampleName(example_names, d));
2100 }
2101 if (features != nullptr) {
2102 uint32 length;
2103 if (!stream.ReadVarint32(&length)) {
2104 return errors::InvalidArgument(
2105 "Invalid protocol message input, example id: ",
2106 ExampleName(example_names, d));
2107 }
2108 auto limit = stream.PushLimit(length);
2109 while (!stream.ExpectAtEnd()) {
2110 StringPiece key, value;
2111 uint32 length;
2112 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2113 !stream.ReadVarint32(&length)) {
2114 return errors::InvalidArgument(
2115 "Invalid protocol message input, example id: ",
2116 ExampleName(example_names, d));
2117 }
2118 auto limit = stream.PushLimit(length);
2119 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2120 !ParseString(&stream, &key) ||
2121 !stream.ExpectTag(kDelimitedTag(2)) ||
2122 !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
2123 return errors::InvalidArgument(
2124 "Invalid protocol message input, example id: ",
2125 ExampleName(example_names, d));
2126 }
2127 stream.PopLimit(limit);
2128 // Only save if this feature was requested.
2129 auto feature_iter = features->find(key);
2130 if (feature_iter != features->end()) {
2131 auto& feature = feature_iter->second;
2132 feature.protos[d] = value;
2133 feature.protos_present[d] = true;
2134 }
2135 }
2136 stream.PopLimit(limit);
2137 }
2138 }
2139 }
2140 return Status::OK();
2141 }
2142
2143 // Populates context_features[k].length based on context_features[k].protos
2144 // (for all k).
GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features)2145 Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2146 FeatureProtosMap* context_features) {
2147 for (auto& c : *context_features) {
2148 FeatureProtos& feature = c.second;
2149 for (int d = 0; d < feature.protos.size(); ++d) {
2150 const auto& proto = feature.protos[d];
2151 if (proto.empty()) continue;
2152 protobuf::io::CodedInputStream stream(
2153 reinterpret_cast<const uint8*>(proto.data()), proto.size());
2154 EnableAliasing(&stream);
2155 int num_elements = GetFeatureLength(feature.dtype, &stream);
2156 if (num_elements < 0) {
2157 return errors::InvalidArgument(
2158 "Name: ", ExampleName(example_names, d),
2159 ", Context feature: ", c.first,
2160 ". Data types don't match. Expected type: ",
2161 DataTypeString(feature.dtype));
2162 }
2163 switch (feature.type) {
2164 case Type::Sparse: // intentional fall-through
2165 case Type::Ragged:
2166 feature.length += num_elements;
2167 break;
2168 case Type::Dense:
2169 feature.length =
2170 std::max(feature.length, static_cast<size_t>(num_elements));
2171 break;
2172 }
2173 }
2174 }
2175 return Status::OK();
2176 }
2177
2178 // Populates sequence_features[k].length and sequence_features[k].num_rows based
2179 // on sequence_features[k].protos (for all k).
GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * sequence_features)2180 Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2181 FeatureProtosMap* sequence_features) {
2182 for (auto& c : *sequence_features) {
2183 FeatureProtos& feature = c.second;
2184 for (int d = 0; d < feature.protos.size(); ++d) {
2185 const auto& proto = feature.protos[d];
2186 if (proto.empty()) continue;
2187
2188 size_t num_rows = 0;
2189 size_t num_elements = 0;
2190 protobuf::io::CodedInputStream stream(
2191 reinterpret_cast<const uint8*>(proto.data()), proto.size());
2192 EnableAliasing(&stream);
2193 while (!stream.ExpectAtEnd()) {
2194 uint32 feature_bytes;
2195 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2196 !stream.ReadVarint32(&feature_bytes)) {
2197 return errors::InvalidArgument("Error in sequence feature ", c.first,
2198 " in example ",
2199 ExampleName(example_names, d));
2200 }
2201 if (feature_bytes > 2) {
2202 auto limit = stream.PushLimit(feature_bytes);
2203 int delta = GetFeatureLength(feature.dtype, &stream);
2204 if (delta < 0) {
2205 return errors::InvalidArgument(
2206 "Name: ", ExampleName(example_names, d),
2207 ", Feature list: ", c.first, ", Index: ", num_rows,
2208 ". Data types don't match. Expected type: ",
2209 DataTypeString(feature.dtype));
2210 }
2211 num_elements += delta;
2212 stream.PopLimit(limit);
2213 } else if (feature_bytes == 2) {
2214 if (!SkipEmptyFeature(&stream, feature.dtype)) {
2215 return errors::InvalidArgument(
2216 "Name: ", ExampleName(example_names, d),
2217 ", Feature list: ", c.first, ", Index: ", num_rows,
2218 ". Data types don't match. Expected type: ",
2219 DataTypeString(feature.dtype));
2220 }
2221 } else if (feature_bytes != 0) {
2222 return errors::InvalidArgument("Error in sequence feature ", c.first,
2223 " in example ",
2224 ExampleName(example_names, d));
2225 }
2226 ++num_rows;
2227 }
2228 switch (feature.type) {
2229 case Type::Sparse:
2230 feature.length += num_elements;
2231 break;
2232 case Type::Ragged:
2233 feature.length += num_elements;
2234 feature.num_rows += num_rows;
2235 break;
2236 case Type::Dense:
2237 feature.length = std::max(feature.length, num_elements);
2238 break;
2239 }
2240 }
2241 }
2242 return Status::OK();
2243 }
2244
2245 // Copies src into dst[dst_offset:dst_offset+src.size], and then increments
2246 // dst_offset by src.size.
CopyTensorIntoTensor(DataType dtype,const Tensor & src,Tensor * dst,size_t * dst_offset)2247 void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
2248 size_t* dst_offset) {
2249 size_t src_size = src.NumElements();
2250 switch (dtype) {
2251 case DT_INT64: {
2252 auto src_t = src.flat<int64>().data();
2253 std::copy(src_t, src_t + src_size,
2254 dst->flat<int64>().data() + *dst_offset);
2255 break;
2256 }
2257 case DT_FLOAT: {
2258 auto src_t = src.flat<float>().data();
2259 std::copy(src_t, src_t + src_size,
2260 dst->flat<float>().data() + *dst_offset);
2261 break;
2262 }
2263 case DT_STRING: {
2264 auto src_t = src.flat<tstring>().data();
2265 std::copy(src_t, src_t + src_size,
2266 dst->flat<tstring>().data() + *dst_offset);
2267 break;
2268 }
2269 default:
2270 ReportUnexpectedDataType(dtype);
2271 }
2272 *dst_offset += src_size;
2273 }
2274
2275 // Parses dense features in `context_features`, and writes their parsed
2276 // values to `context_results`.
ParseContextDenseFeatures(const FeatureProtosMap & context_features,const FastParseExampleConfig & context_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * context_result)2277 Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
2278 const FastParseExampleConfig& context_config,
2279 gtl::ArraySlice<tstring> example_names,
2280 bool is_batch, int num_examples,
2281 Allocator* allocator, Result* context_result) {
2282 for (int t = 0; t < context_config.dense.size(); ++t) {
2283 const auto& c = context_config.dense[t];
2284 const FeatureProtos& feature =
2285 context_features.find(c.feature_name)->second;
2286 TensorShape dense_shape, example_shape;
2287 DataType dtype = c.dtype;
2288 const size_t data_max_elements = feature.length;
2289 if (!c.shape.AsTensorShape(&example_shape) ||
2290 data_max_elements != example_shape.num_elements()) {
2291 return errors::InvalidArgument(
2292 "Inconsistent max number of elements for feature ", c.feature_name,
2293 ": expected ", example_shape.num_elements(), ", but found ",
2294 data_max_elements);
2295 }
2296 if (is_batch) {
2297 dense_shape.AddDim(num_examples);
2298 }
2299 for (const int dim : c.shape.dim_sizes()) {
2300 dense_shape.AddDim(dim);
2301 }
2302 context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2303
2304 Tensor& out = context_result->dense_values[t];
2305 size_t out_offset = 0;
2306
2307 // Fill in the values.
2308 for (int e = 0; e < num_examples; e++) {
2309 size_t num_elements = 0;
2310 const auto& feature_proto = feature.protos[e];
2311 if (!feature.protos_present[e]) {
2312 // Copy the default value, if present. If not, return an error.
2313 if (c.default_value.NumElements() == 0) {
2314 return errors::InvalidArgument(
2315 "Feature: ", c.feature_name,
2316 " (data type: ", DataTypeString(c.dtype), ")",
2317 " is required but could not be found.");
2318 }
2319 CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
2320 num_elements += c.default_value.NumElements();
2321 } else if (!feature_proto.empty()) {
2322 protobuf::io::CodedInputStream stream(
2323 reinterpret_cast<const uint8*>(feature_proto.data()),
2324 feature_proto.size());
2325 EnableAliasing(&stream);
2326 num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
2327 }
2328 if (num_elements != data_max_elements) {
2329 return errors::InvalidArgument(
2330 "Unexpected number of elements in example ",
2331 ExampleName(example_names, e));
2332 }
2333 }
2334 }
2335 return Status::OK();
2336 }
2337
2338 // Parses sparse features in `context_features`, and writes their parsed
2339 // values to `context_results`.
ParseContextSparseFeatures(const FeatureProtosMap & context_features,const FastParseExampleConfig & context_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * context_result)2340 Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
2341 const FastParseExampleConfig& context_config,
2342 gtl::ArraySlice<tstring> example_names,
2343 bool is_batch, int num_examples,
2344 Allocator* allocator,
2345 Result* context_result) {
2346 for (int t = 0; t < context_config.sparse.size(); ++t) {
2347 const auto& c = context_config.sparse[t];
2348 const FeatureProtos& feature =
2349 context_features.find(c.feature_name)->second;
2350 TensorShape indices_shape, values_shape;
2351 DataType dtype = c.dtype;
2352 size_t expected_num_elements = feature.length;
2353 indices_shape.AddDim(expected_num_elements);
2354 indices_shape.AddDim(is_batch ? 2 : 1);
2355 values_shape.AddDim(expected_num_elements);
2356 context_result->sparse_indices[t] =
2357 Tensor(allocator, DT_INT64, indices_shape);
2358 context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
2359 context_result->sparse_shapes[t] =
2360 Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
2361 Tensor& out_values = context_result->sparse_values[t];
2362 size_t out_values_offset = 0;
2363 int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
2364 auto out_shape = context_result->sparse_shapes[t].vec<int64>();
2365
2366 // Fill in the values.
2367 size_t num_elements = 0;
2368 size_t max_num_cols = 0;
2369 for (int e = 0; e < num_examples; e++) {
2370 const auto& feature_proto = feature.protos[e];
2371 if (feature_proto.empty()) continue;
2372 protobuf::io::CodedInputStream stream(
2373 reinterpret_cast<const uint8*>(feature_proto.data()),
2374 feature_proto.size());
2375 EnableAliasing(&stream);
2376 size_t num_added =
2377 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2378 num_elements += num_added;
2379 max_num_cols = std::max(max_num_cols, num_added);
2380 for (int i = 0; i < num_added; i++) {
2381 if (is_batch) *out_indices++ = e;
2382 *out_indices++ = i;
2383 }
2384 }
2385 if (num_elements != expected_num_elements) {
2386 return errors::InvalidArgument(
2387 "Unexpected total number of elements in feature ", c.feature_name);
2388 }
2389 if (is_batch) {
2390 out_shape(0) = num_examples;
2391 out_shape(1) = max_num_cols;
2392 } else {
2393 out_shape(0) = max_num_cols;
2394 }
2395 }
2396 return Status::OK();
2397 }
2398
2399 // Parses ragged features in `context_features`, and writes their parsed
2400 // values to `context_results`.
ParseContextRaggedFeatures(const FeatureProtosMap & context_features,const FastParseExampleConfig & context_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * context_result)2401 Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
2402 const FastParseExampleConfig& context_config,
2403 gtl::ArraySlice<tstring> example_names,
2404 bool is_batch, int num_examples,
2405 Allocator* allocator,
2406 Result* context_result) {
2407 for (int t = 0; t < context_config.ragged.size(); ++t) {
2408 const auto& c = context_config.ragged[t];
2409 const FeatureProtos& feature =
2410 context_features.find(c.feature_name)->second;
2411 TensorShape values_shape, splits_shape;
2412 DataType dtype = c.dtype;
2413 DataType splits_dtype = c.splits_dtype;
2414 size_t expected_num_elements = feature.length;
2415 values_shape.AddDim(expected_num_elements);
2416 if (is_batch) {
2417 splits_shape.AddDim(num_examples + 1);
2418 }
2419 context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
2420 context_result->ragged_splits[t] =
2421 Tensor(allocator, splits_dtype, splits_shape);
2422 Tensor& out_values = context_result->ragged_values[t];
2423 size_t out_values_offset = 0;
2424 int32* int32_splits =
2425 is_batch && splits_dtype == DT_INT32
2426 ? context_result->ragged_splits[t].vec<int32>().data()
2427 : nullptr;
2428 int64* int64_splits =
2429 is_batch && splits_dtype == DT_INT64
2430 ? context_result->ragged_splits[t].vec<int64>().data()
2431 : nullptr;
2432 if (int32_splits) {
2433 *int32_splits++ = 0;
2434 } else if (int64_splits) {
2435 *int64_splits++ = 0;
2436 }
2437
2438 // Fill in the values.
2439 size_t split = 0; // = total number of elements we've seen so far
2440 for (int e = 0; e < num_examples; e++) {
2441 const auto& feature_proto = feature.protos[e];
2442 if (!feature_proto.empty()) {
2443 protobuf::io::CodedInputStream stream(
2444 reinterpret_cast<const uint8*>(feature_proto.data()),
2445 feature_proto.size());
2446 EnableAliasing(&stream);
2447 size_t num_added =
2448 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2449 split += num_added;
2450 }
2451 if (int32_splits) {
2452 *int32_splits++ = split;
2453 } else if (int64_splits) {
2454 *int64_splits++ = split;
2455 }
2456 }
2457 if (split != expected_num_elements) {
2458 return errors::InvalidArgument(
2459 "Unexpected total number of elements in feature ", c.feature_name);
2460 }
2461 if (int32_splits || int64_splits) {
2462 int actual_splits =
2463 int32_splits
2464 ? int32_splits -
2465 context_result->ragged_splits[t].vec<int32>().data()
2466 : int64_splits -
2467 context_result->ragged_splits[t].vec<int64>().data();
2468 if (actual_splits != num_examples + 1) {
2469 return errors::InvalidArgument(
2470 "Unexpected number of examples for feature ", c.feature_name);
2471 }
2472 }
2473 }
2474 return Status::OK();
2475 }
2476
2477 // Parses dense features in `sequence_features`, and writes their parsed
2478 // values to `sequence_result`.
ParseSequenceDenseFeatures(const FeatureProtosMap & sequence_features,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths)2479 Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
2480 const FastParseExampleConfig& sequence_config,
2481 gtl::ArraySlice<tstring> example_names,
2482 bool is_batch, int num_examples,
2483 Allocator* allocator, Result* sequence_result,
2484 std::vector<Tensor>* dense_feature_lengths) {
2485 TensorShape dense_length_shape;
2486 if (is_batch) {
2487 dense_length_shape.AddDim(num_examples);
2488 }
2489 for (int t = 0; t < sequence_config.dense.size(); ++t) {
2490 const auto& c = sequence_config.dense[t];
2491 const FeatureProtos& feature =
2492 sequence_features.find(c.feature_name)->second;
2493 TensorShape dense_shape, row_shape;
2494 DataType dtype = c.dtype;
2495 const size_t expected_max_elements = feature.length;
2496 if (!c.shape.AsTensorShape(&row_shape) ||
2497 expected_max_elements !=
2498 (expected_max_elements / row_shape.num_elements()) *
2499 row_shape.num_elements()) {
2500 PartialTensorShape total_shape = row_shape;
2501 total_shape.InsertDim(0, -1);
2502 return errors::InvalidArgument(
2503 "Feature list '", c.feature_name,
2504 "' has an unexpected number of values. Total values size: ",
2505 expected_max_elements,
2506 " is not consistent with output shape: ", total_shape.DebugString());
2507 }
2508 int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
2509 if (is_batch) {
2510 dense_shape.AddDim(num_examples);
2511 }
2512 dense_shape.AddDim(expected_max_rows);
2513 for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
2514 dense_shape.AddDim(dim);
2515 }
2516 sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2517 (*dense_feature_lengths)[t] =
2518 Tensor(allocator, DT_INT64, dense_length_shape);
2519 int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
2520
2521 tstring* out_bytes = nullptr;
2522 float* out_float = nullptr;
2523 int64* out_int64 = nullptr;
2524 switch (dtype) {
2525 case DT_STRING:
2526 out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
2527 break;
2528 case DT_FLOAT:
2529 out_float = sequence_result->dense_values[t].flat<float>().data();
2530 break;
2531 case DT_INT64:
2532 out_int64 = sequence_result->dense_values[t].flat<int64>().data();
2533 break;
2534 default:
2535 ReportUnexpectedDataType(dtype);
2536 }
2537
2538 // Fill in the values.
2539 for (int e = 0; e < num_examples; e++) {
2540 size_t num_elements = 0, num_rows = 0;
2541 const auto& feature_proto = feature.protos[e];
2542 if (!feature.protos_present[e]) {
2543 // Return an error if this feature was not allowed to be missing.
2544 // Otherwise, we'll pad as needed below.
2545 if (!c.variable_length) {
2546 return errors::InvalidArgument(
2547 "Name: ", ExampleName(example_names, e), ", Feature list '",
2548 c.feature_name,
2549 "' is required but could not be found. "
2550 "Did you mean to include it in "
2551 "feature_list_dense_missing_assumed_empty or "
2552 "feature_list_dense_defaults?");
2553 }
2554 } else if (!feature_proto.empty()) {
2555 protobuf::io::CodedInputStream stream(
2556 reinterpret_cast<const uint8*>(feature_proto.data()),
2557 feature_proto.size());
2558 EnableAliasing(&stream);
2559 while (!stream.ExpectAtEnd()) {
2560 uint32 feature_length;
2561 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2562 !stream.ReadVarint32(&feature_length)) {
2563 return errors::InvalidArgument("Error in sequence feature ",
2564 c.feature_name, " in example ",
2565 ExampleName(example_names, e));
2566 }
2567 auto limit = stream.PushLimit(feature_length);
2568 int num_added = 0;
2569 if (feature_length > 2) {
2570 switch (dtype) {
2571 case DT_STRING:
2572 num_added = ParseBytesFeature(&stream, out_bytes);
2573 out_bytes += num_added;
2574 break;
2575 case DT_FLOAT:
2576 num_added = ParseFloatFeature(&stream, out_float);
2577 out_float += num_added;
2578 break;
2579 case DT_INT64:
2580 num_added = ParseInt64Feature(&stream, out_int64);
2581 out_int64 += num_added;
2582 break;
2583 default:
2584 ReportUnexpectedDataType(dtype);
2585 num_added = 0;
2586 }
2587 if (num_added < 0) {
2588 // This should be unreachable -- we already scanned the feature in
2589 // GetSequenceFeatureLengths, and it hasn't changed since then.
2590 return errors::InvalidArgument("Error in sequence feature ",
2591 c.feature_name, " in example ",
2592 ExampleName(example_names, e));
2593 }
2594 }
2595 if (num_added != row_shape.num_elements()) {
2596 return errors::InvalidArgument(
2597 "Name: ", ExampleName(example_names, e),
2598 ", Key: ", c.feature_name, ", Index: ", num_rows,
2599 ". Number of values != expected. values size: ", num_added,
2600 " but output shape: ", row_shape.DebugString());
2601 }
2602 num_elements += num_added;
2603 num_rows++;
2604 stream.PopLimit(limit);
2605 }
2606 }
2607 *out_lengths++ = num_rows;
2608 // Pad as necessary.
2609 int num_to_pad = expected_max_elements - num_elements;
2610 switch (dtype) {
2611 case DT_STRING:
2612 out_bytes += num_to_pad;
2613 break;
2614 case DT_FLOAT:
2615 PadFloatFeature(num_to_pad, out_float);
2616 out_float += num_to_pad;
2617 break;
2618 case DT_INT64:
2619 PadInt64Feature(num_to_pad, out_int64);
2620 out_int64 += num_to_pad;
2621 break;
2622 default:
2623 ReportUnexpectedDataType(dtype);
2624 }
2625 }
2626 }
2627 return Status::OK();
2628 }
2629
2630 // Parses sparse features in `sequence_features`, and writes their parsed
2631 // values to `sequence_result`.
ParseSequenceSparseFeatures(const FeatureProtosMap & sequence_features,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * sequence_result)2632 Status ParseSequenceSparseFeatures(
2633 const FeatureProtosMap& sequence_features,
2634 const FastParseExampleConfig& sequence_config,
2635 gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
2636 Allocator* allocator, Result* sequence_result) {
2637 for (int t = 0; t < sequence_config.sparse.size(); ++t) {
2638 const auto& c = sequence_config.sparse[t];
2639 const FeatureProtos& feature =
2640 sequence_features.find(c.feature_name)->second;
2641 TensorShape indices_shape, values_shape;
2642 DataType dtype = c.dtype;
2643 size_t expected_num_elements = feature.length;
2644 indices_shape.AddDim(expected_num_elements);
2645 indices_shape.AddDim(is_batch ? 3 : 2);
2646 values_shape.AddDim(expected_num_elements);
2647 sequence_result->sparse_indices[t] =
2648 Tensor(allocator, DT_INT64, indices_shape);
2649 sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
2650 sequence_result->sparse_shapes[t] =
2651 Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));
2652
2653 tstring* out_bytes = nullptr;
2654 float* out_float = nullptr;
2655 int64* out_int64 = nullptr;
2656 switch (dtype) {
2657 case DT_STRING:
2658 out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
2659 break;
2660 case DT_FLOAT:
2661 out_float = sequence_result->sparse_values[t].flat<float>().data();
2662 break;
2663 case DT_INT64:
2664 out_int64 = sequence_result->sparse_values[t].flat<int64>().data();
2665 break;
2666 default:
2667 ReportUnexpectedDataType(dtype);
2668 }
2669 int64* out_indices =
2670 sequence_result->sparse_indices[t].flat<int64>().data();
2671 auto out_shape = sequence_result->sparse_shapes[t].vec<int64>();
2672
2673 // Fill in the values.
2674 size_t num_elements = 0;
2675 size_t max_num_rows = 0;
2676 size_t max_num_cols = 0;
2677 for (int e = 0; e < num_examples; e++) {
2678 const auto& feature_proto = feature.protos[e];
2679 if (feature_proto.empty()) continue;
2680 protobuf::io::CodedInputStream stream(
2681 reinterpret_cast<const uint8*>(feature_proto.data()),
2682 feature_proto.size());
2683 EnableAliasing(&stream);
2684 size_t num_rows = 0;
2685 while (!stream.ExpectAtEnd()) {
2686 uint32 feature_length;
2687 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2688 !stream.ReadVarint32(&feature_length)) {
2689 // This should be unreachable -- we already scanned the feature in
2690 // GetSequenceFeatureLengths, and it hasn't changed since then.
2691 return errors::InvalidArgument("Error in sequence feature ",
2692 c.feature_name, " in example ",
2693 ExampleName(example_names, e));
2694 }
2695 if (feature_length > 2) {
2696 auto limit = stream.PushLimit(feature_length);
2697 size_t num_added;
2698 switch (dtype) {
2699 case DT_STRING:
2700 num_added = ParseBytesFeature(&stream, out_bytes);
2701 out_bytes += num_added;
2702 break;
2703 case DT_FLOAT:
2704 num_added = ParseFloatFeature(&stream, out_float);
2705 out_float += num_added;
2706 break;
2707 case DT_INT64:
2708 num_added = ParseInt64Feature(&stream, out_int64);
2709 out_int64 += num_added;
2710 break;
2711 default:
2712 ReportUnexpectedDataType(dtype);
2713 num_added = 0;
2714 }
2715 num_elements += num_added;
2716 max_num_cols = std::max(max_num_cols, num_added);
2717 for (int i = 0; i < num_added; i++) {
2718 if (is_batch) *out_indices++ = e;
2719 *out_indices++ = num_rows;
2720 *out_indices++ = i;
2721 }
2722 stream.PopLimit(limit);
2723 } else if (feature_length == 2) {
2724 if (!SkipEmptyFeature(&stream, dtype)) {
2725 // This should be unreachable -- we already scanned the feature in
2726 // GetSequenceFeatureLengths, and it hasn't changed since then.
2727 return errors::InvalidArgument("Error in sequence feature ",
2728 c.feature_name, " in example ",
2729 ExampleName(example_names, e));
2730 }
2731 } else if (feature_length != 0) {
2732 // This should be unreachable -- we already scanned the feature in
2733 // GetSequenceFeatureLengths, and it hasn't changed since then.
2734 return errors::InvalidArgument("Error in sequence feature ",
2735 c.feature_name, " in example ",
2736 ExampleName(example_names, e));
2737 }
2738 num_rows++;
2739 }
2740 max_num_rows = std::max(max_num_rows, num_rows);
2741 }
2742 if (num_elements != expected_num_elements) {
2743 return errors::InvalidArgument(
2744 "Unexpected number of elements in feature ", c.feature_name);
2745 }
2746 if (is_batch) {
2747 out_shape(0) = num_examples;
2748 out_shape(1) = max_num_rows;
2749 out_shape(2) = max_num_cols;
2750 } else {
2751 out_shape(0) = max_num_rows;
2752 out_shape(1) = max_num_cols;
2753 }
2754 }
2755 return Status::OK();
2756 }
2757
2758 // Parses ragged features in `sequence_features`, and writes their parsed
2759 // values to `sequence_result`.
ParseSequenceRaggedFeatures(const FeatureProtosMap & sequence_features,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * sequence_result)2760 Status ParseSequenceRaggedFeatures(
2761 const FeatureProtosMap& sequence_features,
2762 const FastParseExampleConfig& sequence_config,
2763 gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
2764 Allocator* allocator, Result* sequence_result) {
2765 for (int t = 0; t < sequence_config.ragged.size(); ++t) {
2766 const auto& c = sequence_config.ragged[t];
2767 const FeatureProtos& feature =
2768 sequence_features.find(c.feature_name)->second;
2769 TensorShape values_shape, inner_splits_shape, outer_splits_shape;
2770 DataType dtype = c.dtype;
2771 DataType splits_dtype = c.splits_dtype;
2772 size_t expected_num_elements = feature.length;
2773 size_t expected_num_rows = feature.num_rows;
2774 values_shape.AddDim(expected_num_elements);
2775 inner_splits_shape.AddDim(expected_num_rows + 1);
2776 if (is_batch) {
2777 outer_splits_shape.AddDim(num_examples + 1);
2778 }
2779 sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
2780 sequence_result->ragged_splits[t] =
2781 Tensor(allocator, splits_dtype, inner_splits_shape);
2782 sequence_result->ragged_outer_splits[t] =
2783 Tensor(allocator, splits_dtype, outer_splits_shape);
2784 Tensor& out_values = sequence_result->ragged_values[t];
2785 size_t out_values_offset = 0;
2786 int32* int32_inner_splits =
2787 splits_dtype == DT_INT32
2788 ? sequence_result->ragged_splits[t].vec<int32>().data()
2789 : nullptr;
2790 int64* int64_inner_splits =
2791 splits_dtype == DT_INT64
2792 ? sequence_result->ragged_splits[t].vec<int64>().data()
2793 : nullptr;
2794 int32* int32_outer_splits =
2795 is_batch && splits_dtype == DT_INT32
2796 ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
2797 : nullptr;
2798 int64* int64_outer_splits =
2799 is_batch && splits_dtype == DT_INT64
2800 ? sequence_result->ragged_outer_splits[t].vec<int64>().data()
2801 : nullptr;
2802 if (int32_inner_splits) {
2803 *int32_inner_splits++ = 0;
2804 } else if (int64_inner_splits) {
2805 *int64_inner_splits++ = 0;
2806 }
2807 if (int32_outer_splits) {
2808 *int32_outer_splits++ = 0;
2809 } else if (int64_outer_splits) {
2810 *int64_outer_splits++ = 0;
2811 }
2812
2813 // Fill in the values.
2814 size_t inner_split = 0; // total number of elements we've seen so far
2815 size_t outer_split = 0; // total number of rows we've seen so far
2816 for (int e = 0; e < num_examples; e++) {
2817 const auto& feature_proto = feature.protos[e];
2818 if (!feature_proto.empty()) {
2819 protobuf::io::CodedInputStream stream(
2820 reinterpret_cast<const uint8*>(feature_proto.data()),
2821 feature_proto.size());
2822 EnableAliasing(&stream);
2823 while (!stream.ExpectAtEnd()) {
2824 uint32 feature_length;
2825 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2826 !stream.ReadVarint32(&feature_length)) {
2827 // This should be unreachable -- we already scanned the feature in
2828 // GetSequenceFeatureLengths, and it hasn't changed since then.
2829 return errors::InvalidArgument("Error in sequence feature ",
2830 c.feature_name, " in example ",
2831 ExampleName(example_names, e));
2832 }
2833 if (feature_length > 2) {
2834 auto limit = stream.PushLimit(feature_length);
2835 size_t num_added =
2836 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2837 inner_split += num_added;
2838 stream.PopLimit(limit);
2839 } else if (feature_length == 2) {
2840 if (!SkipEmptyFeature(&stream, dtype)) {
2841 // This should be unreachable -- we already scanned the feature in
2842 // GetSequenceFeatureLengths, and it hasn't changed since then.
2843 return errors::InvalidArgument("Error in sequence feature ",
2844 c.feature_name, " in example ",
2845 ExampleName(example_names, e));
2846 }
2847 } else if (feature_length != 0) {
2848 // This should be unreachable -- we already scanned the feature in
2849 // GetSequenceFeatureLengths, and it hasn't changed since then.
2850 return errors::InvalidArgument("Error in sequence feature ",
2851 c.feature_name, " in example ",
2852 ExampleName(example_names, e));
2853 }
2854 if (int32_inner_splits) {
2855 *int32_inner_splits++ = inner_split;
2856 } else if (int64_inner_splits) {
2857 *int64_inner_splits++ = inner_split;
2858 }
2859 outer_split++;
2860 }
2861 }
2862 if (int32_outer_splits) {
2863 *int32_outer_splits++ = outer_split;
2864 } else if (int64_outer_splits) {
2865 *int64_outer_splits++ = outer_split;
2866 }
2867 }
2868 if (outer_split != expected_num_rows) {
2869 return errors::InvalidArgument("Unexpected number of rows for feature ",
2870 c.feature_name);
2871 }
2872 if (inner_split != expected_num_elements) {
2873 return errors::InvalidArgument(
2874 "Unexpected number of elements for feature ", c.feature_name);
2875 }
2876
2877 if (int32_inner_splits || int64_inner_splits) {
2878 const auto& inner_splits = sequence_result->ragged_splits[t];
2879 int num_inner_splits =
2880 int32_inner_splits
2881 ? int32_inner_splits - inner_splits.vec<int32>().data()
2882 : int64_inner_splits - inner_splits.vec<int64>().data();
2883 if (num_inner_splits != expected_num_rows + 1) {
2884 return errors::InvalidArgument("Unexpected number of rows for feature ",
2885 c.feature_name);
2886 }
2887 }
2888 if (int32_outer_splits || int64_outer_splits) {
2889 const auto& outer_splits = sequence_result->ragged_outer_splits[t];
2890 int num_outer_splits =
2891 int32_outer_splits
2892 ? int32_outer_splits - outer_splits.vec<int32>().data()
2893 : int64_outer_splits - outer_splits.vec<int64>().data();
2894 if (num_outer_splits != num_examples + 1) {
2895 return errors::InvalidArgument(
2896 "Unexpected number of examples for feature ", c.feature_name);
2897 }
2898 }
2899 }
2900 return Status::OK();
2901 }
2902
2903 } // namespace
2904
2905 // TODO(sundberg): Use the threadpool to parallelize example parsing.
2906 // TODO(b/111553342): Support extracting feature statistics from the examples.
FastParseSequenceExample(const FastParseExampleConfig & context_config,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * context_result,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths,bool is_batch)2907 Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
2908 const FastParseExampleConfig& sequence_config,
2909 gtl::ArraySlice<tstring> serialized,
2910 gtl::ArraySlice<tstring> example_names,
2911 thread::ThreadPool* thread_pool,
2912 Result* context_result, Result* sequence_result,
2913 std::vector<Tensor>* dense_feature_lengths,
2914 bool is_batch) {
2915 int num_examples = serialized.size();
2916 DCHECK(context_result != nullptr);
2917 DCHECK(sequence_result != nullptr);
2918 DCHECK(dense_feature_lengths != nullptr);
2919 size_t num_context_features = context_config.sparse.size() +
2920 context_config.dense.size() +
2921 context_config.ragged.size();
2922 FeatureProtosMap context_features;
2923 context_features.reserve(num_context_features);
2924
2925 if (!example_names.empty() && example_names.size() != num_examples) {
2926 return errors::InvalidArgument(
2927 "example_names must be empty or have the correct number of elements");
2928 }
2929 for (auto& c : context_config.sparse) {
2930 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2931 FeatureProtos& feature = context_features[c.feature_name];
2932 feature.dtype = c.dtype;
2933 feature.length = 0;
2934 feature.type = Type::Sparse;
2935 feature.protos.resize(num_examples);
2936 feature.protos_present.resize(num_examples);
2937 }
2938 for (auto& c : context_config.ragged) {
2939 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2940 FeatureProtos& feature = context_features[c.feature_name];
2941 if (feature.type == Type::Sparse) {
2942 return errors::InvalidArgument("Context feature " + c.feature_name +
2943 " cannot be both ragged and sparse");
2944 }
2945 feature.dtype = c.dtype;
2946 feature.length = 0;
2947 feature.type = Type::Ragged;
2948 feature.protos.resize(num_examples);
2949 feature.protos_present.resize(num_examples);
2950 }
2951 for (auto& c : context_config.dense) {
2952 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2953 FeatureProtos& feature = context_features[c.feature_name];
2954 if (feature.type != Type::Dense) {
2955 return errors::InvalidArgument("Context feature " + c.feature_name +
2956 " cannot be both dense and sparse");
2957 }
2958 if (c.default_value.NumElements() > 0) {
2959 if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
2960 return errors::InvalidArgument("Default value for context feature ",
2961 c.feature_name,
2962 " has an incorrect shape: saw ",
2963 c.default_value.shape().DebugString(),
2964 " but expected ", c.shape.DebugString());
2965 }
2966 }
2967 feature.dtype = c.dtype;
2968 feature.length = c.default_value.NumElements();
2969 feature.protos.resize(num_examples);
2970 feature.protos_present.resize(num_examples);
2971 }
2972 size_t num_sequence_features = sequence_config.sparse.size() +
2973 sequence_config.dense.size() +
2974 sequence_config.ragged.size();
2975 FeatureProtosMap sequence_features;
2976 sequence_features.reserve(num_sequence_features);
2977 for (auto& c : sequence_config.sparse) {
2978 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2979 FeatureProtos& feature = sequence_features[c.feature_name];
2980 feature.dtype = c.dtype;
2981 feature.length = 0;
2982 feature.type = Type::Sparse;
2983 feature.protos.resize(num_examples);
2984 feature.protos_present.resize(num_examples);
2985 }
2986 for (auto& c : sequence_config.ragged) {
2987 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2988 FeatureProtos& feature = sequence_features[c.feature_name];
2989 if (feature.type == Type::Sparse) {
2990 return errors::InvalidArgument("Sequence feature " + c.feature_name +
2991 " cannot be both ragged and sparse");
2992 }
2993 feature.dtype = c.dtype;
2994 feature.length = 0;
2995 feature.type = Type::Ragged;
2996 feature.protos.resize(num_examples);
2997 feature.protos_present.resize(num_examples);
2998 }
2999 for (auto& c : sequence_config.dense) {
3000 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3001 FeatureProtos& feature = sequence_features[c.feature_name];
3002 if (feature.type != Type::Dense) {
3003 return errors::InvalidArgument("Sequence feature " + c.feature_name +
3004 " cannot be both dense and sparse");
3005 }
3006 feature.dtype = c.dtype;
3007 feature.length = 0;
3008 feature.protos.resize(num_examples);
3009 feature.protos_present.resize(num_examples);
3010 }
3011
3012 // Find the serialized proto substrings for each feature.
3013 TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
3014 serialized, example_names, &context_features, &sequence_features));
3015
3016 // Scan through the protos to determine how much memory we need to allocate.
3017 TF_RETURN_IF_ERROR(
3018 GetContextFeatureLengths(example_names, &context_features));
3019 TF_RETURN_IF_ERROR(
3020 GetSequenceFeatureLengths(example_names, &sequence_features));
3021
3022 // Allocate memory.
3023 context_result->sparse_values.resize(context_config.sparse.size());
3024 context_result->sparse_indices.resize(context_config.sparse.size());
3025 context_result->sparse_shapes.resize(context_config.sparse.size());
3026 context_result->dense_values.resize(context_config.dense.size());
3027 context_result->ragged_values.resize(context_config.ragged.size());
3028 context_result->ragged_splits.resize(context_config.ragged.size());
3029 context_result->ragged_outer_splits.resize(context_config.ragged.size());
3030 sequence_result->sparse_values.resize(sequence_config.sparse.size());
3031 sequence_result->sparse_indices.resize(sequence_config.sparse.size());
3032 sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
3033 sequence_result->dense_values.resize(sequence_config.dense.size());
3034 sequence_result->ragged_values.resize(sequence_config.ragged.size());
3035 sequence_result->ragged_splits.resize(sequence_config.ragged.size());
3036 sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
3037 dense_feature_lengths->resize(sequence_config.dense.size());
3038
3039 // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
3040 // to avoid lock contention in `tensorflow::cpu_allocator()`.
3041 Allocator* allocator = tensorflow::cpu_allocator();
3042
3043 TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
3044 context_features, context_config, example_names, is_batch, num_examples,
3045 allocator, context_result));
3046 TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
3047 context_features, context_config, example_names, is_batch, num_examples,
3048 allocator, context_result));
3049 TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
3050 context_features, context_config, example_names, is_batch, num_examples,
3051 allocator, context_result));
3052 TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
3053 sequence_features, sequence_config, example_names, is_batch, num_examples,
3054 allocator, sequence_result, dense_feature_lengths));
3055 TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
3056 sequence_features, sequence_config, example_names, is_batch, num_examples,
3057 allocator, sequence_result));
3058 TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
3059 sequence_features, sequence_config, example_names, is_batch, num_examples,
3060 allocator, sequence_result));
3061
3062 return Status::OK();
3063 }
3064
3065 } // namespace example
3066 } // namespace tensorflow
3067