1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/testing/annotator.h"
18 
19 #include "utils/flatbuffers/mutable.h"
20 #include "flatbuffers/reflection.h"
21 
22 namespace libtextclassifier3 {
23 
FirstResult(const std::vector<ClassificationResult> & results)24 std::string FirstResult(const std::vector<ClassificationResult>& results) {
25   if (results.empty()) {
26     return "<INVALID RESULTS>";
27   }
28   return results[0].collection;
29 }
30 
// Reads the whole file at `file_name` and returns its contents as a string.
// If the file cannot be opened, the returned string is empty.
std::string ReadFile(const std::string& file_name) {
  // Open in binary mode so the bytes come back exactly as stored; a text-mode
  // stream may translate line endings on some platforms, which would corrupt
  // non-text content.
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
35 
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score,const float priority_score)36 std::unique_ptr<RegexModel_::PatternT> MakePattern(
37     const std::string& collection_name, const std::string& pattern,
38     const bool enabled_for_classification, const bool enabled_for_selection,
39     const bool enabled_for_annotation, const float score,
40     const float priority_score) {
41   std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
42   result->collection_name = collection_name;
43   result->pattern = pattern;
44   // We cannot directly operate with |= on the flag, so use an int here.
45   int enabled_modes = ModeFlag_NONE;
46   if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
47   if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
48   if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
49   result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
50   result->target_classification_score = score;
51   result->priority_score = priority_score;
52   return result;
53 }
54 
55 // Shortcut function that doesn't need to specify the priority score.
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score)56 std::unique_ptr<RegexModel_::PatternT> MakePattern(
57     const std::string& collection_name, const std::string& pattern,
58     const bool enabled_for_classification, const bool enabled_for_selection,
59     const bool enabled_for_annotation, const float score) {
60   return MakePattern(collection_name, pattern, enabled_for_classification,
61                      enabled_for_selection, enabled_for_annotation,
62                      /*score=*/score,
63                      /*priority_score=*/score);
64 }
65 
AddTestRegexModel(ModelT * unpacked_model)66 void AddTestRegexModel(ModelT* unpacked_model) {
67   // Add test regex models.
68   unpacked_model->regex_model->patterns.push_back(MakePattern(
69       "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
70       /*enabled_for_classification=*/true,
71       /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
72 
73   // Use meta data to generate custom serialized entity data.
74   MutableFlatbufferBuilder entity_data_builder(
75       flatbuffers::GetRoot<reflection::Schema>(
76           unpacked_model->entity_data_schema.data()));
77   RegexModel_::PatternT* pattern =
78       unpacked_model->regex_model->patterns.back().get();
79 
80   {
81     std::unique_ptr<MutableFlatbuffer> entity_data =
82         entity_data_builder.NewRoot();
83     entity_data->Set("is_alive", true);
84     pattern->serialized_entity_data = entity_data->Serialize();
85   }
86   pattern->capturing_group.emplace_back(new CapturingGroupT);
87   pattern->capturing_group.emplace_back(new CapturingGroupT);
88   pattern->capturing_group.emplace_back(new CapturingGroupT);
89   pattern->capturing_group.emplace_back(new CapturingGroupT);
90   // Group 0 is the full match, capturing groups starting at 1.
91   pattern->capturing_group[1]->entity_field_path.reset(
92       new FlatbufferFieldPathT);
93   pattern->capturing_group[1]->entity_field_path->field.emplace_back(
94       new FlatbufferFieldT);
95   pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
96       "first_name";
97   pattern->capturing_group[2]->entity_field_path.reset(
98       new FlatbufferFieldPathT);
99   pattern->capturing_group[2]->entity_field_path->field.emplace_back(
100       new FlatbufferFieldT);
101   pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
102       "last_name";
103   // Set `former_us_president` field if we match Obama.
104   {
105     std::unique_ptr<MutableFlatbuffer> entity_data =
106         entity_data_builder.NewRoot();
107     entity_data->Set("former_us_president", true);
108     pattern->capturing_group[2]->serialized_entity_data =
109         entity_data->Serialize();
110   }
111   pattern->capturing_group[3]->entity_field_path.reset(
112       new FlatbufferFieldPathT);
113   pattern->capturing_group[3]->entity_field_path->field.emplace_back(
114       new FlatbufferFieldT);
115   pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
116       "age";
117 }
118 
CreateEmptyModel(const std::function<void (ModelT * model)> model_update_fn)119 std::string CreateEmptyModel(
120     const std::function<void(ModelT* model)> model_update_fn) {
121   ModelT model;
122   model_update_fn(&model);
123 
124   flatbuffers::FlatBufferBuilder builder;
125   FinishModelBuffer(builder, Model::Pack(builder, &model));
126   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
127                      builder.GetSize());
128 }
129 
130 // Create fake entity data schema meta data.
AddTestEntitySchemaData(ModelT * unpacked_model)131 void AddTestEntitySchemaData(ModelT* unpacked_model) {
132   // Cannot use object oriented API here as that is not available for the
133   // reflection schema.
134   flatbuffers::FlatBufferBuilder schema_builder;
135   std::vector<flatbuffers::Offset<reflection::Field>> fields = {
136       reflection::CreateField(
137           schema_builder,
138           /*name=*/schema_builder.CreateString("first_name"),
139           /*type=*/
140           reflection::CreateType(schema_builder,
141                                  /*base_type=*/reflection::String),
142           /*id=*/0,
143           /*offset=*/4),
144       reflection::CreateField(
145           schema_builder,
146           /*name=*/schema_builder.CreateString("is_alive"),
147           /*type=*/
148           reflection::CreateType(schema_builder,
149                                  /*base_type=*/reflection::Bool),
150           /*id=*/1,
151           /*offset=*/6),
152       reflection::CreateField(
153           schema_builder,
154           /*name=*/schema_builder.CreateString("last_name"),
155           /*type=*/
156           reflection::CreateType(schema_builder,
157                                  /*base_type=*/reflection::String),
158           /*id=*/2,
159           /*offset=*/8),
160       reflection::CreateField(
161           schema_builder,
162           /*name=*/schema_builder.CreateString("age"),
163           /*type=*/
164           reflection::CreateType(schema_builder,
165                                  /*base_type=*/reflection::Int),
166           /*id=*/3,
167           /*offset=*/10),
168       reflection::CreateField(
169           schema_builder,
170           /*name=*/schema_builder.CreateString("former_us_president"),
171           /*type=*/
172           reflection::CreateType(schema_builder,
173                                  /*base_type=*/reflection::Bool),
174           /*id=*/4,
175           /*offset=*/12)};
176   std::vector<flatbuffers::Offset<reflection::Enum>> enums;
177   std::vector<flatbuffers::Offset<reflection::Object>> objects = {
178       reflection::CreateObject(
179           schema_builder,
180           /*name=*/schema_builder.CreateString("EntityData"),
181           /*fields=*/
182           schema_builder.CreateVectorOfSortedTables(&fields))};
183   schema_builder.Finish(reflection::CreateSchema(
184       schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
185       schema_builder.CreateVectorOfSortedTables(&enums),
186       /*(unused) file_ident=*/0,
187       /*(unused) file_ext=*/0,
188       /*root_table*/ objects[0]));
189 
190   unpacked_model->entity_data_schema.assign(
191       schema_builder.GetBufferPointer(),
192       schema_builder.GetBufferPointer() + schema_builder.GetSize());
193 }
194 
MakeAnnotatedSpan(CodepointSpan span,const std::string & collection,const float score,AnnotatedSpan::Source source)195 AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
196                                 const std::string& collection,
197                                 const float score,
198                                 AnnotatedSpan::Source source) {
199   AnnotatedSpan result;
200   result.span = span;
201   result.classification.push_back({collection, score});
202   result.source = source;
203   return result;
204 }
205 
206 }  // namespace libtextclassifier3
207