1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/cached-features.h"
18 
19 #include "annotator/model-executor.h"
20 #include "utils/tensor-view.h"
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 
25 using testing::ElementsAreArray;
26 using testing::FloatEq;
27 using testing::Matcher;
28 
29 namespace libtextclassifier3 {
30 namespace {
31 
ElementsAreFloat(const std::vector<float> & values)32 Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
33   std::vector<Matcher<float>> matchers;
34   for (const float value : values) {
35     matchers.push_back(FloatEq(value));
36   }
37   return ElementsAreArray(matchers);
38 }
39 
// Creates a synthetic, token-major feature buffer with three floats per token.
// For each token index t in [1, num_tokens], it appends
// (t * 11, -(t * 11), t * 0.1). A non-positive `num_tokens` yields an empty
// vector.
std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
  std::unique_ptr<std::vector<float>> result(new std::vector<float>());
  std::vector<float>& flat = *result;
  for (int token = 1; token <= num_tokens; token++) {
    const float scaled = token * 11.0f;
    flat.push_back(scaled);
    // Negating the product gives the same bits as (-token) * 11.0f, because
    // IEEE round-to-nearest is sign-symmetric.
    flat.push_back(-scaled);
    flat.push_back(token * 0.1f);
  }
  return result;
}
49 
GetCachedClickContextFeatures(const CachedFeatures & cached_features,int click_pos)50 std::vector<float> GetCachedClickContextFeatures(
51     const CachedFeatures& cached_features, int click_pos) {
52   std::vector<float> output_features;
53   cached_features.AppendClickContextFeaturesForClick(click_pos,
54                                                      &output_features);
55   return output_features;
56 }
57 
GetCachedBoundsSensitiveFeatures(const CachedFeatures & cached_features,TokenSpan selected_span)58 std::vector<float> GetCachedBoundsSensitiveFeatures(
59     const CachedFeatures& cached_features, TokenSpan selected_span) {
60   std::vector<float> output_features;
61   cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
62                                                        &output_features);
63   return output_features;
64 }
65 
// Checks that click-context features concatenate the feature vectors of the
// clicked token and `context_size` (= 2) neighbours on each side.
// NOTE(review): the span {3, 10} covers 7 tokens, while MakeFeatures(9)
// supplies 9. The windows queried below stay in range. Confirm whether the
// mismatch is intended.
TEST(CachedFeaturesTest, ClickContext) {
  FeatureProcessorOptionsT options_obj;
  options_obj.feature_version = 1;
  options_obj.context_size = 2;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(CreateFeatureProcessorOptions(fbb, &options_obj));
  // `serialized_options` owns the bytes that `options_table` points into, so
  // it must outlive `cached_features`.
  flatbuffers::DetachedBuffer serialized_options = fbb.Release();
  const FeatureProcessorOptions* options_table =
      flatbuffers::GetRoot<FeatureProcessorOptions>(serialized_options.data());

  std::unique_ptr<std::vector<float>> padding(
      new std::vector<float>{112233.0, -112233.0, 321.0});

  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create({3, 10}, MakeFeatures(9), std::move(padding),
                             options_table, /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
              ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
                                0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
              ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
                                0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
              ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
                                0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
}
97 
TEST(CachedFeaturesTest,BoundsSensitive)98 TEST(CachedFeaturesTest, BoundsSensitive) {
99   std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
100       new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
101   config->enabled = true;
102   config->num_tokens_before = 2;
103   config->num_tokens_inside_left = 2;
104   config->num_tokens_inside_right = 2;
105   config->num_tokens_after = 2;
106   config->include_inside_bag = true;
107   config->include_inside_length = true;
108   FeatureProcessorOptionsT options;
109   options.bounds_sensitive_features = std::move(config);
110   options.feature_version = 2;
111   flatbuffers::FlatBufferBuilder builder;
112   builder.Finish(CreateFeatureProcessorOptions(builder, &options));
113   flatbuffers::DetachedBuffer options_fb = builder.Release();
114 
115   std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
116   std::unique_ptr<std::vector<float>> padding_features(
117       new std::vector<float>{112233.0, -112233.0, 321.0});
118 
119   const std::unique_ptr<CachedFeatures> cached_features =
120       CachedFeatures::Create(
121           {3, 9}, std::move(features), std::move(padding_features),
122           flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
123           /*feature_vector_size=*/3);
124   ASSERT_TRUE(cached_features);
125 
126   EXPECT_THAT(
127       GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
128       ElementsAreFloat({11.0,     -11.0,     0.1,   22.0,  -22.0, 0.2,   33.0,
129                         -33.0,    0.3,       44.0,  -44.0, 0.4,   44.0,  -44.0,
130                         0.4,      55.0,      -55.0, 0.5,   66.0,  -66.0, 0.6,
131                         112233.0, -112233.0, 321.0, 44.0,  -44.0, 0.4,   3.0}));
132 
133   EXPECT_THAT(
134       GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
135       ElementsAreFloat({11.0,  -11.0, 0.1,   22.0,  -22.0, 0.2,   33.0,
136                         -33.0, 0.3,   44.0,  -44.0, 0.4,   33.0,  -33.0,
137                         0.3,   44.0,  -44.0, 0.4,   55.0,  -55.0, 0.5,
138                         66.0,  -66.0, 0.6,   38.5,  -38.5, 0.35,  2.0}));
139 
140   EXPECT_THAT(
141       GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
142       ElementsAreFloat({22.0,     -22.0,     0.2,   33.0,  -33.0, 0.3,   44.0,
143                         -44.0,    0.4,       55.0,  -55.0, 0.5,   44.0,  -44.0,
144                         0.4,      55.0,      -55.0, 0.5,   66.0,  -66.0, 0.6,
145                         112233.0, -112233.0, 321.0, 49.5,  -49.5, 0.45,  2.0}));
146 
147   EXPECT_THAT(
148       GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
149       ElementsAreFloat({22.0,     -22.0,     0.2,   33.0,     -33.0,     0.3,
150                         44.0,     -44.0,     0.4,   112233.0, -112233.0, 321.0,
151                         112233.0, -112233.0, 321.0, 44.0,     -44.0,     0.4,
152                         55.0,     -55.0,     0.5,   66.0,     -66.0,     0.6,
153                         44.0,     -44.0,     0.4,   1.0}));
154 }
155 
156 }  // namespace
157 }  // namespace libtextclassifier3
158