/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/duration/duration.h"

#include <string>
#include <vector>

#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
#include "utils/test-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {
namespace {

using testing::AllOf;
using testing::ElementsAre;
using testing::Field;

// Returns the DurationAnnotatorOptions used by every test in this file.
//
// The options are serialized into a flatbuffer once (on first call) and the
// returned pointer is a view into that buffer. The buffer is deliberately
// heap-allocated and never freed so the view stays valid for the lifetime of
// the test binary.
const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
  static const flatbuffers::DetachedBuffer* options_data = []() {
    DurationAnnotatorOptionsT options;
    options.enabled = true;

    // Unit words recognized for each granularity.
    options.week_expressions = {"week", "weeks"};
    options.day_expressions = {"day", "days"};
    options.hour_expressions = {"hour", "hours"};
    options.minute_expressions = {"minute", "minutes"};
    options.second_expressions = {"second", "seconds"};

    // Words that may appear between or before duration parts.
    options.filler_expressions = {"and", "a", "an", "one"};
    options.half_expressions = {"half"};

    flatbuffers::FlatBufferBuilder builder;
    builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
    return new flatbuffers::DetachedBuffer(builder.Release());
  }();

  return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
}

// Builds a FeatureProcessor whose only tokenization rule treats codepoint 32
// (ASCII space) as a whitespace separator.
//
// As with TestingDurationAnnotatorOptions(), the serialized options buffer is
// intentionally leaked: the returned FeatureProcessor keeps a pointer into it.
FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
  static const flatbuffers::DetachedBuffer* options_data = []() {
    FeatureProcessorOptionsT options;
    options.context_size = 1;
    options.max_selection_span = 1;
    options.snap_label_span_boundaries_to_containing_tokens = false;
    // Commas are registered as ignored span-boundary codepoints; presumably
    // this is what lets StripsPunctuationFromTokens match ",minutes,".
    options.ignored_span_boundary_codepoints.push_back(',');

    options.tokenization_codepoint_config.emplace_back(
        new TokenizationCodepointRangeT());
    auto& config = options.tokenization_codepoint_config.back();
    // NOTE(review): assumes `end` is exclusive, i.e. the range is [32, 33)
    // and covers only ' '.
    config->start = 32;
    config->end = 33;
    config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

    flatbuffers::FlatBufferBuilder builder;
    builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
    return new flatbuffers::DetachedBuffer(builder.Release());
  }();

  const FeatureProcessorOptions* feature_processor_options =
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
  return FeatureProcessor(feature_processor_options, unilib);
}

// Test fixture that wires a DurationAnnotator to the test-only options and
// feature processor defined above.
class DurationAnnotatorTest : public ::testing::Test {
 protected:
  DurationAnnotatorTest()
      : INIT_UNILIB_FOR_TESTING(unilib_),
        feature_processor_(BuildFeatureProcessor(&unilib_)),
        duration_annotator_(TestingDurationAnnotatorOptions(),
                            &feature_processor_) {}

  // Tokenizes `text` with the fixture's feature processor.
  std::vector<Token> Tokenize(const UnicodeText& text) {
    return feature_processor_.Tokenize(text);
  }

  // Declaration order matters: members are initialized in this order and
  // each one is passed (by address) to the constructor of the next.
  UniLib unilib_;
  FeatureProcessor feature_processor_;
  DurationAnnotator duration_annotator_;
};

// A plain "<number> <unit>" selection is classified with the right length.
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
  ClassificationResult classification;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification)); EXPECT_THAT(classification, AllOf(Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 15 * 60 * 1000))); } TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 15 * 60 * 1000))))))); } TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 3.5 * 60 * 1000))))))); } TEST_F(DurationAnnotatorTest, FindsComposedDuration) { const UnicodeText text = UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 35)), 
Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 3 * 60 * 60 * 1000 + 5 * 1000))))))); } TEST_F(DurationAnnotatorTest, FindsHalfAnHour) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 0.5 * 60 * 60 * 1000))))))); } TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 1 hour and a half"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 1.5 * 60 * 60 * 1000))))))); } TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for an hour and a half"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 1.5 * 60 * 60 * 1000))))))); } TEST_F(DurationAnnotatorTest, 
FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 39)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 10 * 60 * 1000 + 1 * 1000))))))); } TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) { const UnicodeText text = UTF8ToUnicodeText( "Set a timer for a a a 10 minutes and 2 seconds an and an ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(22, 46)), Field(&AnnotatedSpan::classification, ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 10 * 60 * 1000 + 2 * 1000))))))); } TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); ASSERT_EQ(result.size(), 0); } TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?"); std::vector tokens = Tokenize(text); std::vector result; EXPECT_TRUE(duration_annotator_.FindAll( text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); EXPECT_THAT( result, ElementsAre( AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 46)), Field(&AnnotatedSpan::classification, 
ElementsAre(AllOf( Field(&ClassificationResult::collection, "duration"), Field(&ClassificationResult::duration_ms, 10 * 60 * 1000 + 2 * 1000))))))); } } // namespace } // namespace libtextclassifier3