1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
19 
20 #include <memory>
21 
22 #include "actions/test-utils.h"
23 #include "annotator/grammar/grammar-annotator.h"
24 #include "utils/flatbuffers/mutable.h"
25 #include "utils/jvm-test-utils.h"
26 #include "utils/utf8/unilib.h"
27 #include "gtest/gtest.h"
28 
29 namespace libtextclassifier3 {
30 
// TODO(sofian): Move these matchers to a higher-level library so that more
// tests in text_classifier can use them.
33 MATCHER_P3(IsAnnotatedSpan, start, end, collection,
34            "is annotated span with begin that " +
35                ::testing::DescribeMatcher<int>(start, negation) +
36                ", end that " + ::testing::DescribeMatcher<int>(end, negation) +
37                ", collection that " +
38                ::testing::DescribeMatcher<std::string>(collection, negation)) {
39   return ::testing::ExplainMatchResult(CodepointSpan(start, end), arg.span,
40                                        result_listener) &&
41          ::testing::ExplainMatchResult(::testing::StrEq(collection),
42                                        arg.classification.front().collection,
43                                        result_listener);
44 }
45 
46 MATCHER_P(IsClassificationResult, collection,
47           "is classification result with collection that " +
48               ::testing::DescribeMatcher<std::string>(collection, negation)) {
49   return ::testing::ExplainMatchResult(::testing::StrEq(collection),
50                                        arg.collection, result_listener);
51 }
52 
53 class GrammarAnnotatorTest : public ::testing::Test {
54  protected:
GrammarAnnotatorTest()55   GrammarAnnotatorTest()
56       : unilib_(CreateUniLibForTesting()),
57         serialized_entity_data_schema_(TestEntityDataSchema()),
58         entity_data_builder_(new MutableFlatbufferBuilder(
59             flatbuffers::GetRoot<reflection::Schema>(
60                 serialized_entity_data_schema_.data()))) {}
61 
62   GrammarAnnotator CreateGrammarAnnotator(
63       const ::flatbuffers::DetachedBuffer& serialized_model);
64 
65   std::unique_ptr<UniLib> unilib_;
66   const std::string serialized_entity_data_schema_;
67   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
68 };
69 
// Applies test tokenizer options to `*model`, which is modified through the
// non-const pointer. The function is defined out of line.
// NOTE(review): which tokenizer fields are set is not visible in this header;
// see the corresponding .cc.
void SetTestTokenizerOptions(GrammarModelT* model);
71 
72 }  // namespace libtextclassifier3
73 
74 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
75