1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/parsing/matcher.h"
18
19 #include <iostream>
20 #include <limits>
21
22 #include "utils/base/endian.h"
23 #include "utils/base/logging.h"
24 #include "utils/base/macros.h"
25 #include "utils/grammar/types.h"
26 #include "utils/strings/utf8.h"
27
28 namespace libtextclassifier3::grammar {
29 namespace {
30
31 // Iterator that just enumerates the bytes in a utf8 text.
32 struct ByteIterator {
ByteIteratorlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator33 explicit ByteIterator(StringPiece text)
34 : data(text.data()), end(text.data() + text.size()) {}
35
Nextlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator36 inline char Next() {
37 TC3_DCHECK(HasNext());
38 const char c = data[0];
39 data++;
40 return c;
41 }
HasNextlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator42 inline bool HasNext() const { return data < end; }
43
44 const char* data;
45 const char* end;
46 };
47
48 // Iterator that lowercases a utf8 string on the fly and enumerates the bytes.
49 struct LowercasingByteIterator {
LowercasingByteIteratorlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator50 LowercasingByteIterator(const UniLib* unilib, StringPiece text)
51 : unilib(*unilib),
52 data(text.data()),
53 end(text.data() + text.size()),
54 buffer_pos(0),
55 buffer_size(0) {}
56
Nextlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator57 inline char Next() {
58 // Queue next character.
59 if (buffer_pos >= buffer_size) {
60 buffer_pos = 0;
61
62 // Lower-case the next character. The character and its lower-cased
63 // counterpart may be represented with a different number of bytes in
64 // utf8.
65 buffer_size =
66 ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
67 data += GetNumBytesForUTF8Char(data);
68 }
69 TC3_DCHECK_LT(buffer_pos, buffer_size);
70 return buffer[buffer_pos++];
71 }
72
HasNextlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator73 inline bool HasNext() const {
74 // Either we are not at the end of the data or didn't consume all bytes of
75 // the current character.
76 return (data < end || buffer_pos < buffer_size);
77 }
78
79 const UniLib& unilib;
80 const char* data;
81 const char* end;
82
83 // Each unicode codepoint can have up to 4 utf8 encoding bytes.
84 char buffer[4];
85 int buffer_pos;
86 int buffer_size;
87 };
88
89 // Searches a terminal match within a sorted table of terminals.
90 // Using `LowercasingByteIterator` allows to lower-case the query string on the
91 // fly.
92 template <typename T>
FindTerminal(T input_iterator,const char * strings,const uint32 * offsets,const int num_terminals,int * terminal_index)93 const char* FindTerminal(T input_iterator, const char* strings,
94 const uint32* offsets, const int num_terminals,
95 int* terminal_index) {
96 int left = 0;
97 int right = num_terminals;
98 int span_size = right - left;
99 int match_length = 0;
100
101 // Loop invariant:
102 // At the ith iteration, all strings in the range `left` ... `right` match the
103 // input on the first `match_length` characters.
104 while (true) {
105 const unsigned char c =
106 static_cast<const unsigned char>(input_iterator.Next());
107
108 // We find the possible range of strings in `left` ... `right` matching the
109 // `match_length` + 1 character with two binary searches:
110 // 1) `lower_bound` to find the start of the range of matching strings.
111 // 2) `upper_bound` to find the non-inclusive end of the range.
112 left =
113 (std::lower_bound(
114 offsets + left, offsets + right, c,
115 [strings, match_length](uint32 string_offset, uint32 c) -> bool {
116 return static_cast<unsigned char>(
117 strings[string_offset + match_length]) <
118 LittleEndian::ToHost32(c);
119 }) -
120 offsets);
121 right =
122 (std::upper_bound(
123 offsets + left, offsets + right, c,
124 [strings, match_length](uint32 c, uint32 string_offset) -> bool {
125 return LittleEndian::ToHost32(c) <
126 static_cast<unsigned char>(
127 strings[string_offset + match_length]);
128 }) -
129 offsets);
130 span_size = right - left;
131 if (span_size <= 0) {
132 return nullptr;
133 }
134 ++match_length;
135
136 // By the loop invariant and due to the fact that the strings are sorted,
137 // a matching string will be at `left` now.
138 if (!input_iterator.HasNext()) {
139 const int string_offset = LittleEndian::ToHost32(offsets[left]);
140 if (strings[string_offset + match_length] == 0) {
141 *terminal_index = left;
142 return &strings[string_offset];
143 }
144 return nullptr;
145 }
146 }
147
148 // No match found.
149 return nullptr;
150 }
151
152 // Finds terminal matches in the terminal rules hash tables.
153 // In case a match is found, `terminal` will be set to point into the
154 // terminals string pool.
155 template <typename T>
FindTerminalMatches(T input_iterator,const RulesSet * rules_set,const RulesSet_::Rules_::TerminalRulesMap * terminal_rules,StringPiece * terminal)156 const RulesSet_::LhsSet* FindTerminalMatches(
157 T input_iterator, const RulesSet* rules_set,
158 const RulesSet_::Rules_::TerminalRulesMap* terminal_rules,
159 StringPiece* terminal) {
160 const int terminal_size = terminal->size();
161 if (terminal_size < terminal_rules->min_terminal_length() ||
162 terminal_size > terminal_rules->max_terminal_length()) {
163 return nullptr;
164 }
165 int terminal_index;
166 if (const char* terminal_match = FindTerminal(
167 input_iterator, rules_set->terminals()->data(),
168 terminal_rules->terminal_offsets()->data(),
169 terminal_rules->terminal_offsets()->size(), &terminal_index)) {
170 *terminal = StringPiece(terminal_match, terminal->length());
171 return rules_set->lhs_set()->Get(
172 terminal_rules->lhs_set_index()->Get(terminal_index));
173 }
174 return nullptr;
175 }
176
177 // Finds unary rules matches.
FindUnaryRulesMatches(const RulesSet * rules_set,const RulesSet_::Rules * rules,const Nonterm nonterminal)178 const RulesSet_::LhsSet* FindUnaryRulesMatches(const RulesSet* rules_set,
179 const RulesSet_::Rules* rules,
180 const Nonterm nonterminal) {
181 if (!rules->unary_rules()) {
182 return nullptr;
183 }
184 if (const RulesSet_::Rules_::UnaryRulesEntry* entry =
185 rules->unary_rules()->LookupByKey(nonterminal)) {
186 return rules_set->lhs_set()->Get(entry->value());
187 }
188 return nullptr;
189 }
190
191 // Finds binary rules matches.
FindBinaryRulesMatches(const RulesSet * rules_set,const RulesSet_::Rules * rules,const TwoNonterms nonterminals)192 const RulesSet_::LhsSet* FindBinaryRulesMatches(
193 const RulesSet* rules_set, const RulesSet_::Rules* rules,
194 const TwoNonterms nonterminals) {
195 if (!rules->binary_rules()) {
196 return nullptr;
197 }
198
199 // Lookup in rules hash table.
200 const uint32 bucket_index =
201 BinaryRuleHasher()(nonterminals) % rules->binary_rules()->size();
202
203 // Get hash table bucket.
204 if (const RulesSet_::Rules_::BinaryRuleTableBucket* bucket =
205 rules->binary_rules()->Get(bucket_index)) {
206 if (bucket->rules() == nullptr) {
207 return nullptr;
208 }
209
210 // Check all entries in the chain.
211 for (const RulesSet_::Rules_::BinaryRule* rule : *bucket->rules()) {
212 if (rule->rhs_first() == nonterminals.first &&
213 rule->rhs_second() == nonterminals.second) {
214 return rules_set->lhs_set()->Get(rule->lhs_set_index());
215 }
216 }
217 }
218
219 return nullptr;
220 }
221
GetLhs(const RulesSet * rules_set,const int lhs_entry,Nonterm * nonterminal,CallbackId * callback,int64 * param,int8 * max_whitespace_gap)222 inline void GetLhs(const RulesSet* rules_set, const int lhs_entry,
223 Nonterm* nonterminal, CallbackId* callback, int64* param,
224 int8* max_whitespace_gap) {
225 if (lhs_entry > 0) {
226 // Direct encoding of the nonterminal.
227 *nonterminal = lhs_entry;
228 *callback = kNoCallback;
229 *param = 0;
230 *max_whitespace_gap = -1;
231 } else {
232 const RulesSet_::Lhs* lhs = rules_set->lhs()->Get(-lhs_entry);
233 *nonterminal = lhs->nonterminal();
234 *callback = lhs->callback_id();
235 *param = lhs->callback_param();
236 *max_whitespace_gap = lhs->max_whitespace_gap();
237 }
238 }
239
240 } // namespace
241
Finish()242 void Matcher::Finish() {
243 // Check any pending items.
244 ProcessPendingExclusionMatches();
245 }
246
QueueForProcessing(ParseTree * item)247 void Matcher::QueueForProcessing(ParseTree* item) {
248 // Push element to the front.
249 item->next = pending_items_;
250 pending_items_ = item;
251 }
252
QueueForPostCheck(ExclusionNode * item)253 void Matcher::QueueForPostCheck(ExclusionNode* item) {
254 // Push element to the front.
255 item->next = pending_exclusion_items_;
256 pending_exclusion_items_ = item;
257 }
258
AddTerminal(const CodepointSpan codepoint_span,const int match_offset,StringPiece terminal)259 void Matcher::AddTerminal(const CodepointSpan codepoint_span,
260 const int match_offset, StringPiece terminal) {
261 TC3_CHECK_GE(codepoint_span.second, last_end_);
262
263 // Finish any pending post-checks.
264 if (codepoint_span.second > last_end_) {
265 ProcessPendingExclusionMatches();
266 }
267
268 last_end_ = codepoint_span.second;
269 for (const RulesSet_::Rules* shard : rules_shards_) {
270 // Try case-sensitive matches.
271 if (const RulesSet_::LhsSet* lhs_set =
272 FindTerminalMatches(ByteIterator(terminal), rules_,
273 shard->terminal_rules(), &terminal)) {
274 // `terminal` points now into the rules string pool, providing a
275 // stable reference.
276 ExecuteLhsSet(
277 codepoint_span, match_offset,
278 /*whitespace_gap=*/(codepoint_span.first - match_offset),
279 [terminal](ParseTree* parse_tree) {
280 parse_tree->terminal = terminal.data();
281 parse_tree->rhs2 = nullptr;
282 },
283 lhs_set);
284 }
285
286 // Try case-insensitive matches.
287 if (const RulesSet_::LhsSet* lhs_set = FindTerminalMatches(
288 LowercasingByteIterator(&unilib_, terminal), rules_,
289 shard->lowercase_terminal_rules(), &terminal)) {
290 // `terminal` points now into the rules string pool, providing a
291 // stable reference.
292 ExecuteLhsSet(
293 codepoint_span, match_offset,
294 /*whitespace_gap=*/(codepoint_span.first - match_offset),
295 [terminal](ParseTree* parse_tree) {
296 parse_tree->terminal = terminal.data();
297 parse_tree->rhs2 = nullptr;
298 },
299 lhs_set);
300 }
301 }
302 ProcessPendingSet();
303 }
304
AddParseTree(ParseTree * parse_tree)305 void Matcher::AddParseTree(ParseTree* parse_tree) {
306 TC3_CHECK_GE(parse_tree->codepoint_span.second, last_end_);
307
308 // Finish any pending post-checks.
309 if (parse_tree->codepoint_span.second > last_end_) {
310 ProcessPendingExclusionMatches();
311 }
312
313 last_end_ = parse_tree->codepoint_span.second;
314 QueueForProcessing(parse_tree);
315 ProcessPendingSet();
316 }
317
ExecuteLhsSet(const CodepointSpan codepoint_span,const int match_offset_bytes,const int whitespace_gap,const std::function<void (ParseTree *)> & initializer_fn,const RulesSet_::LhsSet * lhs_set)318 void Matcher::ExecuteLhsSet(
319 const CodepointSpan codepoint_span, const int match_offset_bytes,
320 const int whitespace_gap,
321 const std::function<void(ParseTree*)>& initializer_fn,
322 const RulesSet_::LhsSet* lhs_set) {
323 TC3_CHECK(lhs_set);
324 ParseTree* parse_tree = nullptr;
325 Nonterm prev_lhs = kUnassignedNonterm;
326 for (const int32 lhs_entry : *lhs_set->lhs()) {
327 Nonterm lhs;
328 CallbackId callback_id;
329 int64 callback_param;
330 int8 max_whitespace_gap;
331 GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param,
332 &max_whitespace_gap);
333
334 // Check that the allowed whitespace gap limit is followed.
335 if (max_whitespace_gap >= 0 && whitespace_gap > max_whitespace_gap) {
336 continue;
337 }
338
339 // Handle callbacks.
340 switch (static_cast<DefaultCallback>(callback_id)) {
341 case DefaultCallback::kAssertion: {
342 AssertionNode* assertion_node = arena_->AllocAndInit<AssertionNode>(
343 lhs, codepoint_span, match_offset_bytes,
344 /*negative=*/(callback_param != 0));
345 initializer_fn(assertion_node);
346 QueueForProcessing(assertion_node);
347 continue;
348 }
349 case DefaultCallback::kMapping: {
350 MappingNode* mapping_node = arena_->AllocAndInit<MappingNode>(
351 lhs, codepoint_span, match_offset_bytes, /*id=*/callback_param);
352 initializer_fn(mapping_node);
353 QueueForProcessing(mapping_node);
354 continue;
355 }
356 case DefaultCallback::kExclusion: {
357 // We can only check the exclusion once all matches up to this position
358 // have been processed. Schedule and post check later.
359 ExclusionNode* exclusion_node = arena_->AllocAndInit<ExclusionNode>(
360 lhs, codepoint_span, match_offset_bytes,
361 /*exclusion_nonterm=*/callback_param);
362 initializer_fn(exclusion_node);
363 QueueForPostCheck(exclusion_node);
364 continue;
365 }
366 case DefaultCallback::kSemanticExpression: {
367 SemanticExpressionNode* expression_node =
368 arena_->AllocAndInit<SemanticExpressionNode>(
369 lhs, codepoint_span, match_offset_bytes,
370 /*expression=*/
371 rules_->semantic_expression()->Get(callback_param));
372 initializer_fn(expression_node);
373 QueueForProcessing(expression_node);
374 continue;
375 }
376 default:
377 break;
378 }
379
380 if (prev_lhs != lhs) {
381 prev_lhs = lhs;
382 parse_tree = arena_->AllocAndInit<ParseTree>(
383 lhs, codepoint_span, match_offset_bytes, ParseTree::Type::kDefault);
384 initializer_fn(parse_tree);
385 QueueForProcessing(parse_tree);
386 }
387
388 if (static_cast<DefaultCallback>(callback_id) ==
389 DefaultCallback::kRootRule) {
390 chart_.AddDerivation(Derivation{parse_tree, /*rule_id=*/callback_param});
391 }
392 }
393 }
394
ProcessPendingSet()395 void Matcher::ProcessPendingSet() {
396 while (pending_items_) {
397 // Process.
398 ParseTree* item = pending_items_;
399 pending_items_ = pending_items_->next;
400
401 // Add it to the chart.
402 chart_.Add(item);
403
404 // Check unary rules that trigger.
405 for (const RulesSet_::Rules* shard : rules_shards_) {
406 if (const RulesSet_::LhsSet* lhs_set =
407 FindUnaryRulesMatches(rules_, shard, item->lhs)) {
408 ExecuteLhsSet(
409 item->codepoint_span, item->match_offset,
410 /*whitespace_gap=*/
411 (item->codepoint_span.first - item->match_offset),
412 [item](ParseTree* parse_tree) {
413 parse_tree->rhs1 = nullptr;
414 parse_tree->rhs2 = item;
415 },
416 lhs_set);
417 }
418 }
419
420 // Check binary rules that trigger.
421 // Lookup by begin.
422 for (Chart<>::Iterator it = chart_.MatchesEndingAt(item->match_offset);
423 !it.Done(); it.Next()) {
424 const ParseTree* prev = it.Item();
425 for (const RulesSet_::Rules* shard : rules_shards_) {
426 if (const RulesSet_::LhsSet* lhs_set =
427 FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) {
428 ExecuteLhsSet(
429 /*codepoint_span=*/
430 {prev->codepoint_span.first, item->codepoint_span.second},
431 prev->match_offset,
432 /*whitespace_gap=*/
433 (item->codepoint_span.first -
434 item->match_offset), // Whitespace gap is the gap
435 // between the two parts.
436 [prev, item](ParseTree* parse_tree) {
437 parse_tree->rhs1 = prev;
438 parse_tree->rhs2 = item;
439 },
440 lhs_set);
441 }
442 }
443 }
444 }
445 }
446
ProcessPendingExclusionMatches()447 void Matcher::ProcessPendingExclusionMatches() {
448 while (pending_exclusion_items_) {
449 ExclusionNode* item = pending_exclusion_items_;
450 pending_exclusion_items_ = static_cast<ExclusionNode*>(item->next);
451
452 // Check that the exclusion condition is fulfilled.
453 if (!chart_.HasMatch(item->exclusion_nonterm, item->codepoint_span)) {
454 AddParseTree(item);
455 }
456 }
457 }
458
459 } // namespace libtextclassifier3::grammar
460