1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/grammar/grammar-annotator.h"
18 
19 #include "annotator/feature-processor.h"
20 #include "annotator/grammar/utils.h"
21 #include "annotator/types.h"
22 #include "utils/base/arena.h"
23 #include "utils/base/logging.h"
24 #include "utils/normalization.h"
25 #include "utils/optional.h"
26 #include "utils/utf8/unicodetext.h"
27 
28 namespace libtextclassifier3 {
29 namespace {
30 
31 // Retrieves all capturing nodes from a parse tree.
GetCapturingNodes(const grammar::ParseTree * parse_tree)32 std::unordered_map<uint16, const grammar::ParseTree*> GetCapturingNodes(
33     const grammar::ParseTree* parse_tree) {
34   std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes;
35   for (const grammar::MappingNode* mapping_node :
36        grammar::SelectAllOfType<grammar::MappingNode>(
37            parse_tree, grammar::ParseTree::Type::kMapping)) {
38     capturing_nodes[mapping_node->id] = mapping_node;
39   }
40   return capturing_nodes;
41 }
42 
43 // Computes the selection boundaries from a parse tree.
MatchSelectionBoundaries(const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * classification)44 CodepointSpan MatchSelectionBoundaries(
45     const grammar::ParseTree* parse_tree,
46     const GrammarModel_::RuleClassificationResult* classification) {
47   if (classification->capturing_group() == nullptr) {
48     // Use full match as selection span.
49     return parse_tree->codepoint_span;
50   }
51 
52   // Set information from capturing matches.
53   CodepointSpan span{kInvalidIndex, kInvalidIndex};
54   std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
55       GetCapturingNodes(parse_tree);
56 
57   // Compute span boundaries.
58   for (int i = 0; i < classification->capturing_group()->size(); i++) {
59     auto it = capturing_nodes.find(i);
60     if (it == capturing_nodes.end()) {
61       // Capturing group is not active, skip.
62       continue;
63     }
64     const CapturingGroup* group = classification->capturing_group()->Get(i);
65     if (group->extend_selection()) {
66       if (span.first == kInvalidIndex) {
67         span = it->second->codepoint_span;
68       } else {
69         span.first = std::min(span.first, it->second->codepoint_span.first);
70         span.second = std::max(span.second, it->second->codepoint_span.second);
71       }
72     }
73   }
74   return span;
75 }
76 
77 }  // namespace
78 
GrammarAnnotator(const UniLib * unilib,const GrammarModel * model,const MutableFlatbufferBuilder * entity_data_builder)79 GrammarAnnotator::GrammarAnnotator(
80     const UniLib* unilib, const GrammarModel* model,
81     const MutableFlatbufferBuilder* entity_data_builder)
82     : unilib_(*unilib),
83       model_(model),
84       tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
85       entity_data_builder_(entity_data_builder),
86       analyzer_(unilib, model->rules(), &tokenizer_) {}
87 
88 // Filters out results that do not overlap with a reference span.
OverlappingDerivations(const CodepointSpan & selection,const std::vector<grammar::Derivation> & derivations,const bool only_exact_overlap) const89 std::vector<grammar::Derivation> GrammarAnnotator::OverlappingDerivations(
90     const CodepointSpan& selection,
91     const std::vector<grammar::Derivation>& derivations,
92     const bool only_exact_overlap) const {
93   std::vector<grammar::Derivation> result;
94   for (const grammar::Derivation& derivation : derivations) {
95     // Discard matches that do not match the selection.
96     // Simple check.
97     if (!SpansOverlap(selection, derivation.parse_tree->codepoint_span)) {
98       continue;
99     }
100 
101     // Compute exact selection boundaries (without assertions and
102     // non-capturing parts).
103     const CodepointSpan span = MatchSelectionBoundaries(
104         derivation.parse_tree,
105         model_->rule_classification_result()->Get(derivation.rule_id));
106     if (!SpansOverlap(selection, span) ||
107         (only_exact_overlap && span != selection)) {
108       continue;
109     }
110     result.push_back(derivation);
111   }
112   return result;
113 }
114 
InstantiateAnnotatedSpanFromDerivation(const grammar::TextContext & input_context,const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * interpretation,AnnotatedSpan * result) const115 bool GrammarAnnotator::InstantiateAnnotatedSpanFromDerivation(
116     const grammar::TextContext& input_context,
117     const grammar::ParseTree* parse_tree,
118     const GrammarModel_::RuleClassificationResult* interpretation,
119     AnnotatedSpan* result) const {
120   result->span = MatchSelectionBoundaries(parse_tree, interpretation);
121   ClassificationResult classification;
122   if (!InstantiateClassificationFromDerivation(
123           input_context, parse_tree, interpretation, &classification)) {
124     return false;
125   }
126   result->classification.push_back(classification);
127   return true;
128 }
129 
130 // Instantiates a classification result from a rule match.
InstantiateClassificationFromDerivation(const grammar::TextContext & input_context,const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * interpretation,ClassificationResult * classification) const131 bool GrammarAnnotator::InstantiateClassificationFromDerivation(
132     const grammar::TextContext& input_context,
133     const grammar::ParseTree* parse_tree,
134     const GrammarModel_::RuleClassificationResult* interpretation,
135     ClassificationResult* classification) const {
136   classification->collection = interpretation->collection_name()->str();
137   classification->score = interpretation->target_classification_score();
138   classification->priority_score = interpretation->priority_score();
139 
140   // Assemble entity data.
141   if (entity_data_builder_ == nullptr) {
142     return true;
143   }
144   std::unique_ptr<MutableFlatbuffer> entity_data =
145       entity_data_builder_->NewRoot();
146   if (interpretation->serialized_entity_data() != nullptr) {
147     entity_data->MergeFromSerializedFlatbuffer(
148         StringPiece(interpretation->serialized_entity_data()->data(),
149                     interpretation->serialized_entity_data()->size()));
150   }
151   if (interpretation->entity_data() != nullptr) {
152     entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
153         interpretation->entity_data()));
154   }
155 
156   // Populate entity data from the capturing matches.
157   if (interpretation->capturing_group() != nullptr) {
158     // Gather active capturing matches.
159     std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
160         GetCapturingNodes(parse_tree);
161 
162     for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
163       auto it = capturing_nodes.find(i);
164       if (it == capturing_nodes.end()) {
165         // Capturing group is not active, skip.
166         continue;
167       }
168       const CapturingGroup* group = interpretation->capturing_group()->Get(i);
169 
170       // Add static entity data.
171       if (group->serialized_entity_data() != nullptr) {
172         entity_data->MergeFromSerializedFlatbuffer(
173             StringPiece(interpretation->serialized_entity_data()->data(),
174                         interpretation->serialized_entity_data()->size()));
175       }
176 
177       // Set entity field from captured text.
178       if (group->entity_field_path() != nullptr) {
179         const grammar::ParseTree* capturing_match = it->second;
180         UnicodeText match_text =
181             input_context.Span(capturing_match->codepoint_span);
182         if (group->normalization_options() != nullptr) {
183           match_text = NormalizeText(unilib_, group->normalization_options(),
184                                      match_text);
185         }
186         if (!entity_data->ParseAndSet(group->entity_field_path(),
187                                       match_text.ToUTF8String())) {
188           TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
189           return false;
190         }
191       }
192     }
193   }
194 
195   if (entity_data && entity_data->HasExplicitlySetFields()) {
196     classification->serialized_entity_data = entity_data->Serialize();
197   }
198   return true;
199 }
200 
Annotate(const std::vector<Locale> & locales,const UnicodeText & text,std::vector<AnnotatedSpan> * result) const201 bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
202                                 const UnicodeText& text,
203                                 std::vector<AnnotatedSpan>* result) const {
204   grammar::TextContext input_context =
205       analyzer_.BuildTextContextForInput(text, locales);
206 
207   UnsafeArena arena(/*block_size=*/16 << 10);
208 
209   for (const grammar::Derivation& derivation : ValidDeduplicatedDerivations(
210            analyzer_.parser().Parse(input_context, &arena))) {
211     const GrammarModel_::RuleClassificationResult* interpretation =
212         model_->rule_classification_result()->Get(derivation.rule_id);
213     if ((interpretation->enabled_modes() & ModeFlag_ANNOTATION) == 0) {
214       continue;
215     }
216     result->emplace_back();
217     if (!InstantiateAnnotatedSpanFromDerivation(
218             input_context, derivation.parse_tree, interpretation,
219             &result->back())) {
220       return false;
221     }
222   }
223 
224   return true;
225 }
226 
SuggestSelection(const std::vector<Locale> & locales,const UnicodeText & text,const CodepointSpan & selection,AnnotatedSpan * result) const227 bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
228                                         const UnicodeText& text,
229                                         const CodepointSpan& selection,
230                                         AnnotatedSpan* result) const {
231   if (!selection.IsValid() || selection.IsEmpty()) {
232     return false;
233   }
234 
235   grammar::TextContext input_context =
236       analyzer_.BuildTextContextForInput(text, locales);
237 
238   UnsafeArena arena(/*block_size=*/16 << 10);
239 
240   const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
241   const grammar::ParseTree* best_match = nullptr;
242   for (const grammar::Derivation& derivation :
243        ValidDeduplicatedDerivations(OverlappingDerivations(
244            selection, analyzer_.parser().Parse(input_context, &arena),
245            /*only_exact_overlap=*/false))) {
246     const GrammarModel_::RuleClassificationResult* interpretation =
247         model_->rule_classification_result()->Get(derivation.rule_id);
248     if ((interpretation->enabled_modes() & ModeFlag_SELECTION) == 0) {
249       continue;
250     }
251     if (best_interpretation == nullptr ||
252         interpretation->priority_score() >
253             best_interpretation->priority_score()) {
254       best_interpretation = interpretation;
255       best_match = derivation.parse_tree;
256     }
257   }
258 
259   if (best_interpretation == nullptr) {
260     return false;
261   }
262 
263   return InstantiateAnnotatedSpanFromDerivation(input_context, best_match,
264                                                 best_interpretation, result);
265 }
266 
ClassifyText(const std::vector<Locale> & locales,const UnicodeText & text,const CodepointSpan & selection,ClassificationResult * classification_result) const267 bool GrammarAnnotator::ClassifyText(
268     const std::vector<Locale>& locales, const UnicodeText& text,
269     const CodepointSpan& selection,
270     ClassificationResult* classification_result) const {
271   if (!selection.IsValid() || selection.IsEmpty()) {
272     // Nothing to do.
273     return false;
274   }
275 
276   grammar::TextContext input_context =
277       analyzer_.BuildTextContextForInput(text, locales);
278 
279   if (const TokenSpan context_span = CodepointSpanToTokenSpan(
280           input_context.tokens, selection,
281           /*snap_boundaries_to_containing_tokens=*/true);
282       context_span.IsValid()) {
283     if (model_->context_left_num_tokens() != kInvalidIndex) {
284       input_context.context_span.first =
285           std::max(0, context_span.first - model_->context_left_num_tokens());
286     }
287     if (model_->context_right_num_tokens() != kInvalidIndex) {
288       input_context.context_span.second =
289           std::min(static_cast<int>(input_context.tokens.size()),
290                    context_span.second + model_->context_right_num_tokens());
291     }
292   }
293 
294   UnsafeArena arena(/*block_size=*/16 << 10);
295 
296   const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
297   const grammar::ParseTree* best_match = nullptr;
298   for (const grammar::Derivation& derivation :
299        ValidDeduplicatedDerivations(OverlappingDerivations(
300            selection, analyzer_.parser().Parse(input_context, &arena),
301            /*only_exact_overlap=*/true))) {
302     const GrammarModel_::RuleClassificationResult* interpretation =
303         model_->rule_classification_result()->Get(derivation.rule_id);
304     if ((interpretation->enabled_modes() & ModeFlag_CLASSIFICATION) == 0) {
305       continue;
306     }
307     if (best_interpretation == nullptr ||
308         interpretation->priority_score() >
309             best_interpretation->priority_score()) {
310       best_interpretation = interpretation;
311       best_match = derivation.parse_tree;
312     }
313   }
314 
315   if (best_interpretation == nullptr) {
316     return false;
317   }
318 
319   return InstantiateClassificationFromDerivation(
320       input_context, best_match, best_interpretation, classification_result);
321 }
322 
323 }  // namespace libtextclassifier3
324