1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/example/feature_util.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
22 
23 namespace tensorflow {
24 namespace {
25 
// Absolute tolerance passed to EXPECT_NEAR when comparing float feature values.
constexpr float kTolerance = 1e-5f;
27 
// Verifies that an int64 value stored under a key of an Example can be read
// back through the read-only GetFeatureValues accessor.
TEST(GetFeatureValuesInt64Test, ReadsASingleValue) {
  Example example;
  // Populate the proto directly so that only the read path is exercised.
  auto& feature_map = *example.mutable_features()->mutable_feature();
  feature_map["tag"].mutable_int64_list()->add_value(42);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag = GetFeatureValues<protobuf_int64>("tag", example);

  ASSERT_EQ(1, tag.size());
  EXPECT_EQ(42, tag.Get(0));
}
39 
// Verifies that an int64 value held by a standalone Feature is visible through
// the read-only GetFeatureValues accessor that takes the Feature itself.
TEST(GetFeatureValuesInt64Test, ReadsASingleValueFromFeature) {
  Feature feature;
  auto* int64_list = feature.mutable_int64_list();
  int64_list->add_value(42);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& values = GetFeatureValues<protobuf_int64>(feature);

  ASSERT_EQ(1, values.size());
  EXPECT_EQ(42, values.Get(0));
}
49 
// Verifies that a value added through the mutable GetFeatureValues accessor
// ends up in the int64 list stored under the same key of the Example.
TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
  Example example;

  auto* values = GetFeatureValues<protobuf_int64>("tag", &example);
  values->Add(42);

  // Inspect the raw proto rather than going back through the helper.
  const auto& stored = example.features().feature().at("tag").int64_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_EQ(42, stored.value(0));
}
59 
// Verifies that a value added through the mutable GetFeatureValues accessor
// on a standalone Feature lands in that Feature's int64 list.
TEST(GetFeatureValuesInt64Test, WritesASingleValueToFeature) {
  Feature feature;

  auto* values = GetFeatureValues<protobuf_int64>(&feature);
  values->Add(42);

  const auto& stored = feature.int64_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_EQ(42, stored.value(0));
}
68 
// Verifies that the untyped HasFeature query reports a key only after a value
// has been written under it.
TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) {
  Example example;
  // A freshly constructed Example must not report the key.
  ASSERT_FALSE(HasFeature("tag", example));

  auto* values = GetFeatureValues<protobuf_int64>("tag", &example);
  values->Add(0);

  EXPECT_TRUE(HasFeature("tag", example));
}
77 
// Verifies that the int64-typed HasFeature query ignores a key that currently
// holds float values and reports it once int64 values are written.
TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) {
  Example example;

  // Write the key as a float list first; the int64-typed query must not match.
  GetFeatureValues<float>("tag", &example)->Add(3.14);
  ASSERT_FALSE(HasFeature<protobuf_int64>("tag", example));

  GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);

  EXPECT_TRUE(HasFeature<protobuf_int64>("tag", example));
  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
  ASSERT_EQ(1, tag_ro.size());
  EXPECT_EQ(42, tag_ro.Get(0));
}
91 
// Verifies that the pointer returned by the mutable GetFeatureValues accessor
// can be used as the target of std::copy via a repeated-field back inserter.
TEST(GetFeatureValuesInt64Test, CopyIterableToAField) {
  Example example;
  const std::vector<int> values{1, 2, 3};

  // std::copy is declared in <algorithm>.
  auto* destination = GetFeatureValues<protobuf_int64>("tag", &example);
  std::copy(values.begin(), values.end(),
            protobuf::RepeatedFieldBackInserter(destination));

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_EQ(1, tag_ro.Get(0));
  EXPECT_EQ(2, tag_ro.Get(1));
  EXPECT_EQ(3, tag_ro.Get(2));
}
106 
// Verifies that a float value held by a standalone Feature is visible through
// the read-only GetFeatureValues accessor that takes the Feature itself.
TEST(GetFeatureValuesFloatTest, ReadsASingleValueFromFeature) {
  Feature feature;
  auto* float_list = feature.mutable_float_list();
  float_list->add_value(3.14);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& values = GetFeatureValues<float>(feature);

  ASSERT_EQ(1, values.size());
  EXPECT_NEAR(3.14, values.Get(0), kTolerance);
}
116 
// Verifies that a float value stored under a key of an Example can be read
// back through the read-only GetFeatureValues accessor.
TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
  Example example;
  // Populate the proto directly so that only the read path is exercised.
  auto& feature_map = *example.mutable_features()->mutable_feature();
  feature_map["tag"].mutable_float_list()->add_value(3.14);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag = GetFeatureValues<float>("tag", example);

  ASSERT_EQ(1, tag.size());
  EXPECT_NEAR(3.14, tag.Get(0), kTolerance);
}
128 
// Verifies that a value added through the mutable GetFeatureValues accessor
// on a standalone Feature lands in that Feature's float list.
TEST(GetFeatureValuesFloatTest, WritesASingleValueToFeature) {
  Feature feature;

  auto* values = GetFeatureValues<float>(&feature);
  values->Add(3.14);

  const auto& stored = feature.float_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_NEAR(3.14, stored.value(0), kTolerance);
}
137 
// Verifies that a value added through the mutable GetFeatureValues accessor
// ends up in the float list stored under the same key of the Example.
TEST(GetFeatureValuesFloatTest, WritesASingleValue) {
  Example example;

  auto* values = GetFeatureValues<float>("tag", &example);
  values->Add(3.14);

  // Inspect the raw proto rather than going back through the helper.
  const auto& stored = example.features().feature().at("tag").float_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_NEAR(3.14, stored.value(0), kTolerance);
}
149 
// Verifies that the float-typed HasFeature query ignores a key that currently
// holds int64 values and reports it once float values are written.
TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
  Example example;

  // Write the key as an int64 list first; the float-typed query must not match.
  GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
  ASSERT_FALSE(HasFeature<float>("tag", example));

  GetFeatureValues<float>("tag", &example)->Add(3.14);

  EXPECT_TRUE(HasFeature<float>("tag", example));
  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<float>("tag", example);
  ASSERT_EQ(1, tag_ro.size());
  EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
}
163 
// Same scenario as the typed existence check for floats, but issued through
// ExampleHasFeature (the method this test's name refers to as deprecated).
TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistenceForDeprecatedMethod) {
  Example example;

  // Write the key as an int64 list first; the float-typed query must not match.
  GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
  ASSERT_FALSE(ExampleHasFeature<float>("tag", example));

  GetFeatureValues<float>("tag", &example)->Add(3.14);

  EXPECT_TRUE(ExampleHasFeature<float>("tag", example));
  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<float>("tag", example);
  ASSERT_EQ(1, tag_ro.size());
  EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
}
177 
// Verifies that a bytes value held by a standalone Feature is visible through
// the read-only GetFeatureValues<string> accessor that takes the Feature.
TEST(GetFeatureValuesStringTest, ReadsASingleValueFromFeature) {
  Feature feature;
  auto* bytes_list = feature.mutable_bytes_list();
  bytes_list->add_value("FOO");

  // Bind by const reference so the returned list (and the strings it holds)
  // is not copied into a local.
  const auto& values = GetFeatureValues<string>(feature);

  ASSERT_EQ(1, values.size());
  EXPECT_EQ("FOO", values.Get(0));
}
187 
// Verifies that a bytes value stored under a key of an Example can be read
// back through the read-only GetFeatureValues<string> accessor.
TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
  Example example;
  // Populate the proto directly so that only the read path is exercised.
  auto& feature_map = *example.mutable_features()->mutable_feature();
  feature_map["tag"].mutable_bytes_list()->add_value("FOO");

  // Bind by const reference so the returned list (and the strings it holds)
  // is not copied into a local.
  const auto& tag = GetFeatureValues<string>("tag", example);

  ASSERT_EQ(1, tag.size());
  EXPECT_EQ("FOO", tag.Get(0));
}
199 
// Verifies that a string assigned to an element added through the mutable
// GetFeatureValues<string> accessor lands in the Feature's bytes list.
TEST(GetFeatureValuesStringTest, WritesASingleValueToFeature) {
  Feature feature;

  // Add() yields a pointer to the new element, which is then assigned.
  auto* slot = GetFeatureValues<string>(&feature)->Add();
  *slot = "FOO";

  const auto& stored = feature.bytes_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_EQ("FOO", stored.value(0));
}
208 
// Verifies that a string assigned to an element added through the mutable
// GetFeatureValues<string> accessor ends up in the bytes list stored under
// the same key of the Example.
TEST(GetFeatureValuesStringTest, WritesASingleValue) {
  Example example;

  // Add() yields a pointer to the new element, which is then assigned.
  auto* slot = GetFeatureValues<string>("tag", &example)->Add();
  *slot = "FOO";

  // Inspect the raw proto rather than going back through the helper.
  const auto& stored = example.features().feature().at("tag").bytes_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_EQ("FOO", stored.value(0));
}
219 
// Verifies that the string-typed HasFeature query ignores a key that currently
// holds int64 values and reports it once a string value is written.
TEST(GetFeatureValuesStringTest, CheckTypedFieldExistence) {
  Example example;

  // Write the key as an int64 list first; the string-typed query must not
  // match.
  GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
  ASSERT_FALSE(HasFeature<string>("tag", example));

  *GetFeatureValues<string>("tag", &example)->Add() = "FOO";

  EXPECT_TRUE(HasFeature<string>("tag", example));
  // Bind by const reference so the returned list (and the strings it holds)
  // is not copied into a local.
  const auto& tag_ro = GetFeatureValues<string>("tag", example);
  ASSERT_EQ(1, tag_ro.size());
  EXPECT_EQ("FOO", tag_ro.Get(0));
}
233 
// Verifies that AppendFeatureValues accepts a container of doubles and that
// the values are readable afterwards as a float list under the given key.
TEST(AppendFeatureValuesTest, FloatValuesFromContainer) {
  Example example;

  const std::vector<double> values{1.1, 2.2, 3.3};
  AppendFeatureValues(values, "tag", &example);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<float>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
  EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
  EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
}
246 
// Verifies that AppendFeatureValues accepts a braced list of floating-point
// literals and that the values are readable afterwards as a float list.
TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) {
  Example example;

  AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<float>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
  EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
  EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
}
258 
// Verifies that AppendFeatureValues accepts a braced list of integer literals
// and that the values are readable afterwards as an int64 list.
TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) {
  Example example;

  // Pass a braced list (not a std::vector) so that the call shape matches what
  // this test's name says it covers.
  AppendFeatureValues({1, 2, 3}, "tag", &example);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_EQ(1, tag_ro.Get(0));
  EXPECT_EQ(2, tag_ro.Get(1));
  EXPECT_EQ(3, tag_ro.Get(2));
}
271 
// Verifies that AppendFeatureValues accepts a braced list of string literals
// and that the values are readable afterwards through
// GetFeatureValues<string>.
TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) {
  Example example;

  AppendFeatureValues({"FOO", "BAR", "BAZ"}, "tag", &example);

  // Bind by const reference so the returned list (and the strings it holds)
  // is not copied into a local.
  const auto& tag_ro = GetFeatureValues<string>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_EQ("FOO", tag_ro.Get(0));
  EXPECT_EQ("BAR", tag_ro.Get(1));
  EXPECT_EQ("BAZ", tag_ro.Get(2));
}
283 
// Verifies that AppendFeatureValues accepts a braced list of string variables
// (as opposed to literals) and that the values are readable afterwards
// through GetFeatureValues<string>.
TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) {
  Example example;

  string first("FOO");
  string second("BAR");
  string third("BAZ");

  AppendFeatureValues({first, second, third}, "tag", &example);

  // Bind by const reference so the returned list (and the strings it holds)
  // is not copied into a local.
  const auto& tag_ro = GetFeatureValues<string>("tag", example);
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_EQ("FOO", tag_ro.Get(0));
  EXPECT_EQ("BAR", tag_ro.Get(1));
  EXPECT_EQ("BAZ", tag_ro.Get(2));
}
299 
// Verifies that an int64 value stored in the context of a SequenceExample can
// be read back by passing se.context() to the read-only GetFeatureValues.
TEST(SequenceExampleTest, ReadsASingleValueFromContext) {
  SequenceExample se;
  // Populate the proto directly so that only the read path is exercised.
  auto& context_features = *se.mutable_context()->mutable_feature();
  context_features["tag"].mutable_int64_list()->add_value(42);

  // Bind by const reference so the returned list is not copied into a local.
  const auto& values = GetFeatureValues<protobuf_int64>("tag", se.context());

  ASSERT_EQ(1, values.size());
  EXPECT_EQ(42, values.Get(0));
}
311 
// Verifies that a value added through the mutable GetFeatureValues accessor,
// given se.mutable_context(), ends up in the context's int64 list.
TEST(SequenceExampleTest, WritesASingleValueToContext) {
  SequenceExample se;

  auto* values = GetFeatureValues<protobuf_int64>("tag", se.mutable_context());
  values->Add(42);

  // Inspect the raw proto rather than going back through the helper.
  const auto& stored = se.context().feature().at("tag").int64_list();
  ASSERT_EQ(1, stored.value_size());
  EXPECT_EQ(42, stored.value(0));
}
320 
// Verifies that AppendFeatureValues can target the context of a
// SequenceExample and that the values are readable afterwards as a float list.
TEST(SequenceExampleTest, AppendFeatureValuesToContextSingleArg) {
  SequenceExample se;

  AppendFeatureValues({1.1, 2.2, 3.3}, "tag", se.mutable_context());

  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<float>("tag", se.context());
  ASSERT_EQ(3, tag_ro.size());
  EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
  EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
  EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
}
332 
// Verifies that the int64-typed HasFeature query on a SequenceExample context
// ignores a key that currently holds float values and reports it once int64
// values are written.
TEST(SequenceExampleTest, CheckTypedFieldExistence) {
  SequenceExample se;

  // Write the key as a float list first; the int64-typed query must not match.
  GetFeatureValues<float>("tag", se.mutable_context())->Add(3.14);
  ASSERT_FALSE(HasFeature<protobuf_int64>("tag", se.context()));

  GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);

  EXPECT_TRUE(HasFeature<protobuf_int64>("tag", se.context()));
  // Bind by const reference so the returned list is not copied into a local.
  const auto& tag_ro = GetFeatureValues<protobuf_int64>("tag", se.context());
  ASSERT_EQ(1, tag_ro.size());
  EXPECT_EQ(42, tag_ro.Get(0));
}
346 
// Verifies that the read-only GetFeatureList returns the features already
// stored under a key in the SequenceExample's feature lists.
TEST(SequenceExampleTest, ReturnsExistingFeatureLists) {
  SequenceExample se;
  // Populate the proto directly with one (empty) feature under "tag".
  auto& feature_lists = *se.mutable_feature_lists()->mutable_feature_list();
  feature_lists["tag"].mutable_feature()->Add();

  // Bind by const reference so the returned features are not copied into a
  // local.
  const auto& feature = GetFeatureList("tag", se);

  ASSERT_EQ(1, feature.size());
}
357 
// Verifies that adding an element through the mutable GetFeatureList makes a
// feature list with one feature appear under the key in the SequenceExample.
TEST(SequenceExampleTest, CreatesNewFeatureLists) {
  SequenceExample se;

  auto* features = GetFeatureList("tag", &se);
  features->Add();

  // Inspect the raw proto rather than going back through the helper.
  const auto& stored = se.feature_lists().feature_list().at("tag");
  EXPECT_EQ(1, stored.feature_size());
}
365 
// Verifies that HasFeatureList reports a key only after a feature has been
// added under it through the mutable GetFeatureList.
TEST(SequenceExampleTest, CheckFeatureListExistence) {
  SequenceExample se;
  // A freshly constructed SequenceExample must not report the key.
  ASSERT_FALSE(HasFeatureList("tag", se));

  auto* features = GetFeatureList("tag", &se);
  features->Add();

  ASSERT_TRUE(HasFeatureList("tag", se));
}
374 
// Builds a SequenceExample with an int64 context feature and two bytes
// features in one feature list, all via braced-list AppendFeatureValues calls,
// and compares the resulting proto's DebugString() with the expected text.
TEST(SequenceExampleTest, AppendFeatureValuesWithInitializerList) {
  SequenceExample se;

  AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context());

  // Each Add() appends a new Feature to the "images" list; fill them in order.
  Feature* first_entry = GetFeatureList("images", &se)->Add();
  AppendFeatureValues({"cam1-0", "cam2-0"}, first_entry);

  Feature* second_entry = GetFeatureList("images", &se)->Add();
  AppendFeatureValues({"cam1-1", "cam2-2"}, second_entry);

  const string expected_text =
      "context {\n"
      "  feature {\n"
      "    key: \"ids\"\n"
      "    value {\n"
      "      int64_list {\n"
      "        value: 1\n"
      "        value: 2\n"
      "        value: 3\n"
      "      }\n"
      "    }\n"
      "  }\n"
      "}\n"
      "feature_lists {\n"
      "  feature_list {\n"
      "    key: \"images\"\n"
      "    value {\n"
      "      feature {\n"
      "        bytes_list {\n"
      "          value: \"cam1-0\"\n"
      "          value: \"cam2-0\"\n"
      "        }\n"
      "      }\n"
      "      feature {\n"
      "        bytes_list {\n"
      "          value: \"cam1-1\"\n"
      "          value: \"cam2-2\"\n"
      "        }\n"
      "      }\n"
      "    }\n"
      "  }\n"
      "}\n";
  EXPECT_EQ(se.DebugString(), expected_text);
}
417 
// Appends a vector of floats to a newly added Feature of a feature list and
// compares the resulting proto's DebugString() with the expected text.
TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) {
  SequenceExample se;

  const std::vector<float> readings{1.0, 2.5, 5.0};
  Feature* ratings = GetFeatureList("movie_ratings", &se)->Add();
  AppendFeatureValues(readings, ratings);

  const string expected_text =
      "feature_lists {\n"
      "  feature_list {\n"
      "    key: \"movie_ratings\"\n"
      "    value {\n"
      "      feature {\n"
      "        float_list {\n"
      "          value: 1\n"
      "          value: 2.5\n"
      "          value: 5\n"
      "        }\n"
      "      }\n"
      "    }\n"
      "  }\n"
      "}\n";
  EXPECT_EQ(se.DebugString(), expected_text);
}
440 
441 }  // namespace
442 }  // namespace tensorflow
443