1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/parsing/parser.h"
18 
19 #include <unordered_map>
20 
21 #include "utils/grammar/parsing/parse-tree.h"
22 #include "utils/grammar/rules-utils.h"
23 #include "utils/grammar/types.h"
24 #include "utils/zlib/zlib.h"
25 #include "utils/zlib/zlib_regex.h"
26 
27 namespace libtextclassifier3::grammar {
28 namespace {
29 
CheckMemoryUsage(const UnsafeArena * arena)30 inline bool CheckMemoryUsage(const UnsafeArena* arena) {
31   // The maximum memory usage for matching.
32   constexpr int kMaxMemoryUsage = 1 << 20;
33   return arena->status().bytes_allocated() <= kMaxMemoryUsage;
34 }
35 
36 // Maps a codepoint to include the token padding if it aligns with a token
37 // start. Whitespace is ignored when symbols are fed to the matcher. Preceding
38 // whitespace is merged to the match start so that tokens and non-terminals
39 // appear next to each other without whitespace. For text or regex annotations,
40 // we therefore merge the whitespace padding to the start if the annotation
41 // starts at a token.
MapCodepointToTokenPaddingIfPresent(const std::unordered_map<CodepointIndex,CodepointIndex> & token_alignment,const int start)42 int MapCodepointToTokenPaddingIfPresent(
43     const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
44     const int start) {
45   const auto it = token_alignment.find(start);
46   if (it != token_alignment.end()) {
47     return it->second;
48   }
49   return start;
50 }
51 
52 }  // namespace
53 
Parser(const UniLib * unilib,const RulesSet * rules)54 Parser::Parser(const UniLib* unilib, const RulesSet* rules)
55     : unilib_(*unilib),
56       rules_(rules),
57       lexer_(unilib),
58       nonterminals_(rules_->nonterminals()),
59       rules_locales_(ParseRulesLocales(rules_)),
60       regex_annotators_(BuildRegexAnnotators()) {}
61 
62 // Uncompresses and build the defined regex annotators.
BuildRegexAnnotators() const63 std::vector<Parser::RegexAnnotator> Parser::BuildRegexAnnotators() const {
64   std::vector<RegexAnnotator> result;
65   if (rules_->regex_annotator() != nullptr) {
66     std::unique_ptr<ZlibDecompressor> decompressor =
67         ZlibDecompressor::Instance();
68     result.reserve(rules_->regex_annotator()->size());
69     for (const RulesSet_::RegexAnnotator* regex_annotator :
70          *rules_->regex_annotator()) {
71       result.push_back(
72           {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
73                                       regex_annotator->compressed_pattern(),
74                                       rules_->lazy_regex_compilation(),
75                                       decompressor.get()),
76            regex_annotator->nonterminal()});
77     }
78   }
79   return result;
80 }
81 
SortedSymbolsForInput(const TextContext & input,UnsafeArena * arena) const82 std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
83                                                   UnsafeArena* arena) const {
84   // Whitespace is ignored when symbols are fed to the matcher.
85   // For regex matches and existing text annotations we therefore have to merge
86   // preceding whitespace to the match start so that tokens and non-terminals
87   // appear as next to each other without whitespace. We keep track of real
88   // token starts and precending whitespace in `token_match_start`, so that we
89   // can extend a match's start to include the preceding whitespace.
90   std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
91   for (int i = input.context_span.first + 1; i < input.context_span.second;
92        i++) {
93     const CodepointIndex token_start = input.tokens[i].start;
94     const CodepointIndex prev_token_end = input.tokens[i - 1].end;
95     if (token_start != prev_token_end) {
96       token_match_start[token_start] = prev_token_end;
97     }
98   }
99 
100   std::vector<Symbol> symbols;
101   CodepointIndex match_offset = input.tokens[input.context_span.first].start;
102 
103   // Add start symbol.
104   if (input.context_span.first == 0 &&
105       nonterminals_->start_nt() != kUnassignedNonterm) {
106     match_offset = 0;
107     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
108         nonterminals_->start_nt(), CodepointSpan{0, 0},
109         /*match_offset=*/0, ParseTree::Type::kDefault));
110   }
111 
112   if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
113     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
114         nonterminals_->wordbreak_nt(),
115         CodepointSpan{match_offset, match_offset},
116         /*match_offset=*/match_offset, ParseTree::Type::kDefault));
117   }
118 
119   // Add symbols from tokens.
120   for (int i = input.context_span.first; i < input.context_span.second; i++) {
121     const Token& token = input.tokens[i];
122     lexer_.AppendTokenSymbols(token.value, /*match_offset=*/match_offset,
123                               CodepointSpan{token.start, token.end}, &symbols);
124     match_offset = token.end;
125 
126     // Add word break symbol.
127     if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
128       symbols.emplace_back(arena->AllocAndInit<ParseTree>(
129           nonterminals_->wordbreak_nt(),
130           CodepointSpan{match_offset, match_offset},
131           /*match_offset=*/match_offset, ParseTree::Type::kDefault));
132     }
133   }
134 
135   // Add end symbol if used by the grammar.
136   if (input.context_span.second == input.tokens.size() &&
137       nonterminals_->end_nt() != kUnassignedNonterm) {
138     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
139         nonterminals_->end_nt(), CodepointSpan{match_offset, match_offset},
140         /*match_offset=*/match_offset, ParseTree::Type::kDefault));
141   }
142 
143   // Add symbols from the regex annotators.
144   const CodepointIndex context_start =
145       input.tokens[input.context_span.first].start;
146   const CodepointIndex context_end =
147       input.tokens[input.context_span.second - 1].end;
148   for (const RegexAnnotator& regex_annotator : regex_annotators_) {
149     std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
150         regex_annotator.pattern->Matcher(UnicodeText::Substring(
151             input.text, context_start, context_end, /*do_copy=*/false));
152     int status = UniLib::RegexMatcher::kNoError;
153     while (regex_matcher->Find(&status) &&
154            status == UniLib::RegexMatcher::kNoError) {
155       const CodepointSpan span{regex_matcher->Start(0, &status) + context_start,
156                                regex_matcher->End(0, &status) + context_start};
157       symbols.emplace_back(arena->AllocAndInit<ParseTree>(
158           regex_annotator.nonterm, span, /*match_offset=*/
159           MapCodepointToTokenPaddingIfPresent(token_match_start, span.first),
160           ParseTree::Type::kDefault));
161     }
162   }
163 
164   // Add symbols based on annotations.
165   if (auto annotation_nonterminals = nonterminals_->annotation_nt()) {
166     for (const AnnotatedSpan& annotated_span : input.annotations) {
167       const ClassificationResult& classification =
168           annotated_span.classification.front();
169       if (auto entry = annotation_nonterminals->LookupByKey(
170               classification.collection.c_str())) {
171         symbols.emplace_back(arena->AllocAndInit<AnnotationNode>(
172             entry->value(), annotated_span.span, /*match_offset=*/
173             MapCodepointToTokenPaddingIfPresent(token_match_start,
174                                                 annotated_span.span.first),
175             &classification));
176       }
177     }
178   }
179 
180   std::sort(symbols.begin(), symbols.end(),
181             [](const Symbol& a, const Symbol& b) {
182               // Sort by increasing (end, start) position to guarantee the
183               // matcher requirement that the tokens are fed in non-decreasing
184               // end position order.
185               return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
186                      std::tie(b.codepoint_span.second, b.codepoint_span.first);
187             });
188 
189   return symbols;
190 }
191 
EmitSymbol(const Symbol & symbol,UnsafeArena * arena,Matcher * matcher) const192 void Parser::EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
193                         Matcher* matcher) const {
194   if (!CheckMemoryUsage(arena)) {
195     return;
196   }
197   switch (symbol.type) {
198     case Symbol::Type::TYPE_PARSE_TREE: {
199       // Just emit the parse tree.
200       matcher->AddParseTree(symbol.parse_tree);
201       return;
202     }
203     case Symbol::Type::TYPE_DIGITS: {
204       // Emit <digits> if used by the rules.
205       if (nonterminals_->digits_nt() != kUnassignedNonterm) {
206         matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
207             nonterminals_->digits_nt(), symbol.codepoint_span,
208             symbol.match_offset, ParseTree::Type::kDefault));
209       }
210 
211       // Emit <n_digits> if used by the rules.
212       if (nonterminals_->n_digits_nt() != nullptr) {
213         const int num_digits =
214             symbol.codepoint_span.second - symbol.codepoint_span.first;
215         if (num_digits <= nonterminals_->n_digits_nt()->size()) {
216           const Nonterm n_digits_nt =
217               nonterminals_->n_digits_nt()->Get(num_digits - 1);
218           if (n_digits_nt != kUnassignedNonterm) {
219             matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
220                 nonterminals_->n_digits_nt()->Get(num_digits - 1),
221                 symbol.codepoint_span, symbol.match_offset,
222                 ParseTree::Type::kDefault));
223           }
224         }
225       }
226       break;
227     }
228     case Symbol::Type::TYPE_TERM: {
229       // Emit <uppercase_token> if used by the rules.
230       if (nonterminals_->uppercase_token_nt() != 0 &&
231           unilib_.IsUpperText(
232               UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
233         matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
234             nonterminals_->uppercase_token_nt(), symbol.codepoint_span,
235             symbol.match_offset, ParseTree::Type::kDefault));
236       }
237       break;
238     }
239     default:
240       break;
241   }
242 
243   // Emit the token as terminal.
244   matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
245                        symbol.lexeme);
246 
247   // Emit <token> if used by rules.
248   matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
249       nonterminals_->token_nt(), symbol.codepoint_span, symbol.match_offset,
250       ParseTree::Type::kDefault));
251 }
252 
253 // Parses an input text and returns the root rule derivations.
Parse(const TextContext & input,UnsafeArena * arena) const254 std::vector<Derivation> Parser::Parse(const TextContext& input,
255                                       UnsafeArena* arena) const {
256   // Check the tokens, input can be non-empty (whitespace) but have no tokens.
257   if (input.tokens.empty()) {
258     return {};
259   }
260 
261   // Select locale matching rules.
262   std::vector<const RulesSet_::Rules*> locale_rules =
263       SelectLocaleMatchingShards(rules_, rules_locales_, input.locales);
264 
265   if (locale_rules.empty()) {
266     // Nothing to do.
267     return {};
268   }
269 
270   Matcher matcher(&unilib_, rules_, locale_rules, arena);
271   for (const Symbol& symbol : SortedSymbolsForInput(input, arena)) {
272     EmitSymbol(symbol, arena, &matcher);
273   }
274   matcher.Finish();
275   return matcher.chart().derivations();
276 }
277 
278 }  // namespace libtextclassifier3::grammar
279