/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/cached-features.h"

#include <memory>
#include <utility>
#include <vector>

#include "annotator/model-executor.h"
#include "utils/tensor-view.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using testing::ElementsAreArray;
using testing::FloatEq;
using testing::Matcher;

namespace libtextclassifier3 {
namespace {

// Builds a matcher that checks a float vector element-by-element against
// `values`, using gMock's FloatEq (approximate, ULP-based float equality)
// for each position instead of exact ==.
Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
  std::vector<Matcher<float>> matchers;
  matchers.reserve(values.size());
  for (const float value : values) {
    matchers.push_back(FloatEq(value));
  }
  return ElementsAreArray(matchers);
}

// Returns a flat feature buffer of 3 * num_tokens floats. Token i (1-based)
// contributes the triple {11 * i, -11 * i, 0.1 * i}, so each token's values
// are easy to recognize in the expected outputs below.
std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
  for (int i = 1; i <= num_tokens; ++i) {
    features->push_back(i * 11.0f);
    features->push_back(-i * 11.0f);
    features->push_back(i * 0.1f);
  }
  return features;
}

// Returns whatever `cached_features` appends for a click at `click_pos`,
// starting from an empty output vector.
std::vector<float> GetCachedClickContextFeatures(
    const CachedFeatures& cached_features, int click_pos) {
  std::vector<float> output_features;
  cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                     &output_features);
  return output_features;
}

// Returns whatever `cached_features` appends for `selected_span`, starting
// from an empty output vector.
std::vector<float> GetCachedBoundsSensitiveFeatures(
    const CachedFeatures& cached_features, TokenSpan selected_span) {
  std::vector<float> output_features;
  cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
                                                       &output_features);
  return output_features;
}

TEST(CachedFeaturesTest, ClickContext) {
  FeatureProcessorOptionsT options;
  options.context_size = 2;
  options.feature_version = 1;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
  flatbuffers::DetachedBuffer options_fb = builder.Release();

  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
  std::unique_ptr<std::vector<float>> padding_features(
      new std::vector<float>{112233.0, -112233.0, 321.0});
  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create(
          {3, 10}, std::move(features), std::move(padding_features),
          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
          /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  // NOTE(review): the expectations are consistent with feature row k
  // belonging to token 3 + k (the start of the span passed to Create) and a
  // window of click_pos +/- context_size tokens — inferred from the values,
  // confirm against CachedFeatures.
  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
              ElementsAreFloat({11.0, -11.0, 0.1,  //
                                22.0, -22.0, 0.2,  //
                                33.0, -33.0, 0.3,  //
                                44.0, -44.0, 0.4,  //
                                55.0, -55.0, 0.5}));
  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
              ElementsAreFloat({22.0, -22.0, 0.2,  //
                                33.0, -33.0, 0.3,  //
                                44.0, -44.0, 0.4,  //
                                55.0, -55.0, 0.5,  //
                                66.0, -66.0, 0.6}));
  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
              ElementsAreFloat({33.0, -33.0, 0.3,  //
                                44.0, -44.0, 0.4,  //
                                55.0, -55.0, 0.5,  //
                                66.0, -66.0, 0.6,  //
                                77.0, -77.0, 0.7}));
}

TEST(CachedFeaturesTest, BoundsSensitive) {
  std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
  config->enabled = true;
  config->num_tokens_before = 2;
  config->num_tokens_inside_left = 2;
  config->num_tokens_inside_right = 2;
  config->num_tokens_after = 2;
  config->include_inside_bag = true;
  config->include_inside_length = true;
  FeatureProcessorOptionsT options;
  options.bounds_sensitive_features = std::move(config);
  options.feature_version = 2;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
  flatbuffers::DetachedBuffer options_fb = builder.Release();

  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
  std::unique_ptr<std::vector<float>> padding_features(
      new std::vector<float>{112233.0, -112233.0, 321.0});
  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create(
          {3, 9}, std::move(features), std::move(padding_features),
          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
          /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  // NOTE(review): the expected vectors are consistent with the layout
  // [2 tokens before][2 inside-left][2 inside-right][2 after]
  // [inside bag (mean of inside tokens)][inside length], 3 floats per token,
  // with positions outside the extraction span filled by the padding
  // features — inferred from the config and values, confirm against
  // CachedFeatures.
  EXPECT_THAT(GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
              ElementsAreFloat({11.0, -11.0, 0.1,          //
                                22.0, -22.0, 0.2,          //
                                33.0, -33.0, 0.3,          //
                                44.0, -44.0, 0.4,          //
                                44.0, -44.0, 0.4,          //
                                55.0, -55.0, 0.5,          //
                                66.0, -66.0, 0.6,          //
                                112233.0, -112233.0, 321.0,  //
                                44.0, -44.0, 0.4,          //
                                3.0}));
  EXPECT_THAT(GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
              ElementsAreFloat({11.0, -11.0, 0.1,   //
                                22.0, -22.0, 0.2,   //
                                33.0, -33.0, 0.3,   //
                                44.0, -44.0, 0.4,   //
                                33.0, -33.0, 0.3,   //
                                44.0, -44.0, 0.4,   //
                                55.0, -55.0, 0.5,   //
                                66.0, -66.0, 0.6,   //
                                38.5, -38.5, 0.35,  //
                                2.0}));
  EXPECT_THAT(GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
              ElementsAreFloat({22.0, -22.0, 0.2,          //
                                33.0, -33.0, 0.3,          //
                                44.0, -44.0, 0.4,          //
                                55.0, -55.0, 0.5,          //
                                44.0, -44.0, 0.4,          //
                                55.0, -55.0, 0.5,          //
                                66.0, -66.0, 0.6,          //
                                112233.0, -112233.0, 321.0,  //
                                49.5, -49.5, 0.45,         //
                                2.0}));
  EXPECT_THAT(GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
              ElementsAreFloat({22.0, -22.0, 0.2,          //
                                33.0, -33.0, 0.3,          //
                                44.0, -44.0, 0.4,          //
                                112233.0, -112233.0, 321.0,  //
                                112233.0, -112233.0, 321.0,  //
                                44.0, -44.0, 0.4,          //
                                55.0, -55.0, 0.5,          //
                                66.0, -66.0, 0.6,          //
                                44.0, -44.0, 0.4,          //
                                1.0}));
}

}  // namespace
}  // namespace libtextclassifier3