1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/example_proto_fast_parsing_test.pb.h"
28
29 namespace tensorflow {
30 namespace example {
31 namespace {
32
// Feature keys for dense features, one per supported value type.
// NOTE(review): none of these keys is referenced by the tests visible in this
// file; presumably they are used by benchmarks or kept for future tests —
// confirm before removing.
constexpr char kDenseInt64Key[] = "dense_int64";
constexpr char kDenseFloatKey[] = "dense_float";
constexpr char kDenseStringKey[] = "dense_string";

// Feature keys for sparse features, one per supported value type.
constexpr char kSparseInt64Key[] = "sparse_int64";
constexpr char kSparseFloatKey[] = "sparse_float";
constexpr char kSparseStringKey[] = "sparse_string";
40
SerializedToReadable(string serialized)41 string SerializedToReadable(string serialized) {
42 string result;
43 result += '"';
44 for (char c : serialized)
45 result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
46 result += '"';
47 return result;
48 }
49
// Returns the wire-format bytes of `example`.
// T must provide SerializeToString(string*); the call's boolean result is not
// inspected, so whatever was written to the buffer is returned as-is.
template <typename T>
string Serialize(const T& example) {
  string wire_bytes;
  example.SerializeToString(&wire_bytes);
  return wire_bytes;
}
56
57 // Tests that serialized gets parsed identically by TestFastParse(..)
58 // and the regular Example.ParseFromString(..).
TestCorrectness(const string & serialized)59 void TestCorrectness(const string& serialized) {
60 Example example;
61 Example fast_example;
62 EXPECT_TRUE(example.ParseFromString(serialized));
63 example.DiscardUnknownFields();
64 EXPECT_TRUE(TestFastParse(serialized, &fast_example));
65 EXPECT_EQ(example.DebugString(), fast_example.DebugString());
66 if (example.DebugString() != fast_example.DebugString()) {
67 LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
68 }
69 }
70
71 // Fast parsing does not differentiate between EmptyExample and EmptyFeatures
72 // TEST(FastParse, EmptyExample) {
73 // Example example;
74 // TestCorrectness(example);
75 // }
76
// Concatenates a serialized ExampleWithExtras (features plus extra1..extra7)
// with a serialized plain Example and checks that fast parsing agrees with the
// regular parser on the result. The extras live in the first message.
TEST(FastParse, IgnoresPrecedingUnknownTopLevelFields) {
  ExampleWithExtras with_extras;
  auto& extras_features = *with_extras.mutable_features()->mutable_feature();
  extras_features["age"].mutable_int64_list()->add_value(13);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string serialized =
      strings::StrCat(Serialize(with_extras), Serialize(context));
  TestCorrectness(serialized);
}
100
// Concatenates a serialized plain Example with a serialized ExampleWithExtras
// (features plus extra1..extra7) and checks that fast parsing agrees with the
// regular parser on the result. The extras live in the second message.
TEST(FastParse, IgnoresTrailingUnknownTopLevelFields) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(13);

  ExampleWithExtras with_extras;
  auto& extras_features = *with_extras.mutable_features()->mutable_feature();
  extras_features["zipcode"].mutable_int64_list()->add_value(94043);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  const string serialized =
      strings::StrCat(Serialize(example), Serialize(with_extras));
  TestCorrectness(serialized);
}
124
// Two serialized Examples with different int64 features ("age" and "zipcode")
// are concatenated and must parse the same way through both parsers.
TEST(FastParse, SingleInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(13);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string serialized =
      strings::StrCat(Serialize(example), Serialize(context));
  TestCorrectness(serialized);
}
138
// Two serialized Examples that both define the int64 feature "age" (0 and 15)
// are concatenated. The test first pins what the regular parser yields for
// that input and then requires fast parsing to match it.
TEST(FastParse, DenseInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(0);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["age"].mutable_int64_list()->add_value(15);

  const string serialized = Serialize(example) + Serialize(context);

  {
    Example merged;
    EXPECT_TRUE(merged.ParseFromString(serialized));
    EXPECT_EQ(merged.DebugString(), context.DebugString());
    // Whoa! Last EQ is very surprising, but standard deserialization is what it
    // is and Servo team requested to replicate this 'feature'.
    // In future we should return error.
  }
  TestCorrectness(serialized);
}
162
// Both payloads below are hand-written wire bytes for an Example holding the
// single int64 feature "age" = 13:
//   0a <len>          Example.features
//     0a <len>        Features.feature map entry
//       0a 03 "age"   entry key
//       12 <len>      entry value (Feature)
//         1a <len>    Feature.int64_list
// They differ only in how Int64List.value (field 1) carries the value 13.
// The payloads were previously attached to the opposite test names.

TEST(FastParse, NonPacked) {
  // Trailing "08 0d": field 1 with varint wire type, i.e. a single unpacked
  // element.
  TestCorrectness(
      "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d");
}

TEST(FastParse, Packed) {
  // Trailing "0a 01 0d": field 1 with length-delimited wire type, i.e. a
  // packed run of one element.
  TestCorrectness(
      "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d");
}
172
// An Example whose `features` submessage has been touched through its mutable
// accessor, but which holds no feature entries, must round-trip through both
// parsers identically.
TEST(FastParse, EmptyFeatures) {
  Example example;
  example.mutable_features();
  const string serialized = Serialize(example);
  TestCorrectness(serialized);
}
178
TestCorrectnessJson(const string & json)179 void TestCorrectnessJson(const string& json) {
180 auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
181 "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
182 string serialized;
183 auto s = protobuf::util::JsonToBinaryString(
184 resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
185 EXPECT_TRUE(s.ok()) << s;
186 delete resolver;
187 TestCorrectness(serialized);
188 }
189
// JSON-described Example with one value per feature, covering an int64, a
// float and a bytes feature.
TEST(FastParse, JsonUnivalent) {
  const string json =
      "{'features': {"
      " 'feature': {'age': {'int64_list': {'value': [0]} }}, "
      " 'feature': {'flo': {'float_list': {'value': [1.1]} }}, "
      " 'feature': {'byt': {'bytes_list': {'value': ['WW8='] }}}"
      "}}";
  TestCorrectnessJson(json);
}
198
// JSON-described Example with several values per feature, covering an int64,
// a float and a bytes feature.
TEST(FastParse, JsonMultivalent) {
  const string json =
      "{'features': {"
      " 'feature': {'age': {'int64_list': {'value': [0, 13, 23]} }}, "
      " 'feature': {'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }}, "
      " 'feature': {'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}}"
      "}}";
  TestCorrectnessJson(json);
}
207
// Smallest non-trivial case: one Example with a single int64 feature value.
TEST(FastParse, SingleInt64) {
  Example example;
  auto& features = *example.mutable_features()->mutable_feature();
  features["age"].mutable_int64_list()->add_value(13);

  const string serialized = Serialize(example);
  TestCorrectness(serialized);
}
215
ExampleWithSomeFeatures()216 static string ExampleWithSomeFeatures() {
217 Example example;
218
219 (*example.mutable_features()->mutable_feature())[""];
220
221 (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
222 .mutable_bytes_list();
223 (*example.mutable_features()->mutable_feature())["empty_float_list"]
224 .mutable_float_list();
225 (*example.mutable_features()->mutable_feature())["empty_int64_list"]
226 .mutable_int64_list();
227
228 BytesList* bytes_list =
229 (*example.mutable_features()->mutable_feature())["bytes_list"]
230 .mutable_bytes_list();
231 bytes_list->add_value("bytes1");
232 bytes_list->add_value("bytes2");
233
234 FloatList* float_list =
235 (*example.mutable_features()->mutable_feature())["float_list"]
236 .mutable_float_list();
237 float_list->add_value(1.0);
238 float_list->add_value(2.0);
239
240 Int64List* int64_list =
241 (*example.mutable_features()->mutable_feature())["int64_list"]
242 .mutable_int64_list();
243 int64_list->add_value(3);
244 int64_list->add_value(270);
245 int64_list->add_value(86942);
246
247 return Serialize(example);
248 }
249
// Runs the parser comparison on the mixed fixture from
// ExampleWithSomeFeatures().
TEST(FastParse, SomeFeatures) {
  const string serialized = ExampleWithSomeFeatures();
  TestCorrectness(serialized);
}
251
AddDenseFeature(const char * feature_name,DataType dtype,PartialTensorShape shape,bool variable_length,size_t elements_per_stride,FastParseExampleConfig * out_config)252 static void AddDenseFeature(const char* feature_name, DataType dtype,
253 PartialTensorShape shape, bool variable_length,
254 size_t elements_per_stride,
255 FastParseExampleConfig* out_config) {
256 out_config->dense.emplace_back();
257 auto& new_feature = out_config->dense.back();
258 new_feature.feature_name = feature_name;
259 new_feature.dtype = dtype;
260 new_feature.shape = std::move(shape);
261 new_feature.default_value = Tensor(dtype, {});
262 new_feature.variable_length = variable_length;
263 new_feature.elements_per_stride = elements_per_stride;
264 }
265
AddSparseFeature(const char * feature_name,DataType dtype,FastParseExampleConfig * out_config)266 static void AddSparseFeature(const char* feature_name, DataType dtype,
267 FastParseExampleConfig* out_config) {
268 out_config->sparse.emplace_back();
269 auto& new_feature = out_config->sparse.back();
270 new_feature.feature_name = feature_name;
271 new_feature.dtype = dtype;
272 }
273
// With collect_feature_stats enabled, both the batch and the single-example
// entry points must report 7 features and 7 feature values per example for
// the ExampleWithSomeFeatures() fixture, whichever way the three populated
// features are configured (dense, variable-length dense, sparse, or a mix).
TEST(FastParse, StatsCollection) {
  const size_t kNumExamples = 13;
  std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());

  // All three features as fixed-shape dense features.
  FastParseExampleConfig dense_config;
  dense_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &dense_config);
  AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &dense_config);
  AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &dense_config);

  // All three features as variable-length dense features.
  FastParseExampleConfig varlen_config;
  varlen_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &varlen_config);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &varlen_config);
  AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &varlen_config);

  // All three features as sparse features.
  FastParseExampleConfig sparse_config;
  sparse_config.collect_feature_stats = true;
  AddSparseFeature("bytes_list", DT_STRING, &sparse_config);
  AddSparseFeature("float_list", DT_FLOAT, &sparse_config);
  AddSparseFeature("int64_list", DT_INT64, &sparse_config);

  // One feature of each flavour.
  FastParseExampleConfig mixed_config;
  mixed_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &mixed_config);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &mixed_config);
  AddSparseFeature("int64_list", DT_INT64, &mixed_config);

  const auto expect_seven_features_and_values =
      [](const PerExampleFeatureStats& stats) {
        EXPECT_EQ(7, stats.features_count);
        EXPECT_EQ(7, stats.feature_values_count);
      };

  for (const FastParseExampleConfig& config :
       {dense_config, varlen_config, sparse_config, mixed_config}) {
    // Batch entry point: one stats record per input example.
    {
      Result batch_result;
      TF_CHECK_OK(
          FastParseExample(config, serialized, {}, nullptr, &batch_result));
      EXPECT_EQ(kNumExamples, batch_result.feature_stats.size());
      for (const PerExampleFeatureStats& stats : batch_result.feature_stats) {
        expect_seven_features_and_values(stats);
      }
    }

    // Single-example entry point: exactly one stats record.
    {
      Result single_result;
      TF_CHECK_OK(
          FastParseSingleExample(config, serialized[0], &single_result));
      EXPECT_EQ(1, single_result.feature_stats.size());
      expect_seven_features_and_values(single_result.feature_stats[0]);
    }
  }
}
323
RandStr(random::SimplePhilox * rng)324 string RandStr(random::SimplePhilox* rng) {
325 static const char key_char_lookup[] =
326 "0123456789{}~`!@#$%^&*()"
327 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
328 "abcdefghijklmnopqrstuvwxyz";
329 auto len = 1 + rng->Rand32() % 200;
330 string str;
331 str.reserve(len);
332 while (len-- > 0) {
333 str.push_back(
334 key_char_lookup[rng->Rand32() % (sizeof(key_char_lookup) /
335 sizeof(key_char_lookup[0]))]);
336 }
337 return str;
338 }
339
Fuzz(random::SimplePhilox * rng)340 void Fuzz(random::SimplePhilox* rng) {
341 // Generate keys.
342 auto num_keys = 1 + rng->Rand32() % 100;
343 std::unordered_set<string> unique_keys;
344 for (auto i = 0; i < num_keys; ++i) {
345 unique_keys.emplace(RandStr(rng));
346 }
347
348 // Generate serialized example.
349 Example example;
350 string serialized_example;
351 auto num_concats = 1 + rng->Rand32() % 4;
352 std::vector<Feature::KindCase> feat_types(
353 {Feature::kBytesList, Feature::kFloatList, Feature::kInt64List});
354 std::vector<string> all_keys(unique_keys.begin(), unique_keys.end());
355 while (num_concats--) {
356 example.Clear();
357 auto num_active_keys = 1 + rng->Rand32() % all_keys.size();
358
359 // Generate features.
360 for (auto i = 0; i < num_active_keys; ++i) {
361 auto fkey = all_keys[rng->Rand32() % all_keys.size()];
362 auto ftype_idx = rng->Rand32() % feat_types.size();
363 auto num_features = 1 + rng->Rand32() % 5;
364 switch (static_cast<Feature::KindCase>(feat_types[ftype_idx])) {
365 case Feature::kBytesList: {
366 BytesList* bytes_list =
367 (*example.mutable_features()->mutable_feature())[fkey]
368 .mutable_bytes_list();
369 while (num_features--) {
370 bytes_list->add_value(RandStr(rng));
371 }
372 break;
373 }
374 case Feature::kFloatList: {
375 FloatList* float_list =
376 (*example.mutable_features()->mutable_feature())[fkey]
377 .mutable_float_list();
378 while (num_features--) {
379 float_list->add_value(rng->RandFloat());
380 }
381 break;
382 }
383 case Feature::kInt64List: {
384 Int64List* int64_list =
385 (*example.mutable_features()->mutable_feature())[fkey]
386 .mutable_int64_list();
387 while (num_features--) {
388 int64_list->add_value(rng->Rand64());
389 }
390 break;
391 }
392 default: {
393 LOG(QFATAL);
394 break;
395 }
396 }
397 }
398 serialized_example += example.SerializeAsString();
399 }
400
401 // Test correctness.
402 TestCorrectness(serialized_example);
403 }
404
// Runs 200 fuzz iterations from a fixed seed, logging the number of
// iterations still to come before each one.
TEST(FastParse, FuzzTest) {
  const uint64 seed = 1337;
  random::PhiloxRandom philox(seed);
  random::SimplePhilox rng(&philox);

  const int kNumRuns = 200;
  for (int runs_left = kNumRuns - 1; runs_left >= 0; --runs_left) {
    LOG(INFO) << "runs left: " << runs_left;
    Fuzz(&rng);
  }
}
415
// FastParseExample must succeed when given no serialized examples and no
// example names, even with a sparse feature configured.
TEST(TestFastParseExample, Empty) {
  FastParseExampleConfig config;
  config.sparse.push_back({"test", DT_STRING});

  Result result;
  const gtl::ArraySlice<tstring> no_serialized;
  const gtl::ArraySlice<tstring> no_example_names;
  const Status status = FastParseExample(config, no_serialized,
                                         no_example_names, nullptr, &result);
  EXPECT_TRUE(status.ok()) << status;
}
425
426 } // namespace
427 } // namespace example
428 } // namespace tensorflow
429