1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/parsing/matcher.h"
18 
19 #include <iostream>
20 #include <limits>
21 
22 #include "utils/base/endian.h"
23 #include "utils/base/logging.h"
24 #include "utils/base/macros.h"
25 #include "utils/grammar/types.h"
26 #include "utils/strings/utf8.h"
27 
28 namespace libtextclassifier3::grammar {
29 namespace {
30 
31 // Iterator that just enumerates the bytes in a utf8 text.
32 struct ByteIterator {
ByteIteratorlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator33   explicit ByteIterator(StringPiece text)
34       : data(text.data()), end(text.data() + text.size()) {}
35 
Nextlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator36   inline char Next() {
37     TC3_DCHECK(HasNext());
38     const char c = data[0];
39     data++;
40     return c;
41   }
HasNextlibtextclassifier3::grammar::__anoncbebdf880111::ByteIterator42   inline bool HasNext() const { return data < end; }
43 
44   const char* data;
45   const char* end;
46 };
47 
48 // Iterator that lowercases a utf8 string on the fly and enumerates the bytes.
49 struct LowercasingByteIterator {
LowercasingByteIteratorlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator50   LowercasingByteIterator(const UniLib* unilib, StringPiece text)
51       : unilib(*unilib),
52         data(text.data()),
53         end(text.data() + text.size()),
54         buffer_pos(0),
55         buffer_size(0) {}
56 
Nextlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator57   inline char Next() {
58     // Queue next character.
59     if (buffer_pos >= buffer_size) {
60       buffer_pos = 0;
61 
62       // Lower-case the next character. The character and its lower-cased
63       // counterpart may be represented with a different number of bytes in
64       // utf8.
65       buffer_size =
66           ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
67       data += GetNumBytesForUTF8Char(data);
68     }
69     TC3_DCHECK_LT(buffer_pos, buffer_size);
70     return buffer[buffer_pos++];
71   }
72 
HasNextlibtextclassifier3::grammar::__anoncbebdf880111::LowercasingByteIterator73   inline bool HasNext() const {
74     // Either we are not at the end of the data or didn't consume all bytes of
75     // the current character.
76     return (data < end || buffer_pos < buffer_size);
77   }
78 
79   const UniLib& unilib;
80   const char* data;
81   const char* end;
82 
83   // Each unicode codepoint can have up to 4 utf8 encoding bytes.
84   char buffer[4];
85   int buffer_pos;
86   int buffer_size;
87 };
88 
89 // Searches a terminal match within a sorted table of terminals.
90 // Using `LowercasingByteIterator` allows to lower-case the query string on the
91 // fly.
92 template <typename T>
FindTerminal(T input_iterator,const char * strings,const uint32 * offsets,const int num_terminals,int * terminal_index)93 const char* FindTerminal(T input_iterator, const char* strings,
94                          const uint32* offsets, const int num_terminals,
95                          int* terminal_index) {
96   int left = 0;
97   int right = num_terminals;
98   int span_size = right - left;
99   int match_length = 0;
100 
101   // Loop invariant:
102   // At the ith iteration, all strings in the range `left` ... `right` match the
103   // input on the first `match_length` characters.
104   while (true) {
105     const unsigned char c =
106         static_cast<const unsigned char>(input_iterator.Next());
107 
108     // We find the possible range of strings in `left` ... `right` matching the
109     // `match_length` + 1 character with two binary searches:
110     //    1) `lower_bound` to find the start of the range of matching strings.
111     //    2) `upper_bound` to find the non-inclusive end of the range.
112     left =
113         (std::lower_bound(
114              offsets + left, offsets + right, c,
115              [strings, match_length](uint32 string_offset, uint32 c) -> bool {
116                return static_cast<unsigned char>(
117                           strings[string_offset + match_length]) <
118                       LittleEndian::ToHost32(c);
119              }) -
120          offsets);
121     right =
122         (std::upper_bound(
123              offsets + left, offsets + right, c,
124              [strings, match_length](uint32 c, uint32 string_offset) -> bool {
125                return LittleEndian::ToHost32(c) <
126                       static_cast<unsigned char>(
127                           strings[string_offset + match_length]);
128              }) -
129          offsets);
130     span_size = right - left;
131     if (span_size <= 0) {
132       return nullptr;
133     }
134     ++match_length;
135 
136     // By the loop invariant and due to the fact that the strings are sorted,
137     // a matching string will be at `left` now.
138     if (!input_iterator.HasNext()) {
139       const int string_offset = LittleEndian::ToHost32(offsets[left]);
140       if (strings[string_offset + match_length] == 0) {
141         *terminal_index = left;
142         return &strings[string_offset];
143       }
144       return nullptr;
145     }
146   }
147 
148   // No match found.
149   return nullptr;
150 }
151 
152 // Finds terminal matches in the terminal rules hash tables.
153 // In case a match is found, `terminal` will be set to point into the
154 // terminals string pool.
155 template <typename T>
FindTerminalMatches(T input_iterator,const RulesSet * rules_set,const RulesSet_::Rules_::TerminalRulesMap * terminal_rules,StringPiece * terminal)156 const RulesSet_::LhsSet* FindTerminalMatches(
157     T input_iterator, const RulesSet* rules_set,
158     const RulesSet_::Rules_::TerminalRulesMap* terminal_rules,
159     StringPiece* terminal) {
160   const int terminal_size = terminal->size();
161   if (terminal_size < terminal_rules->min_terminal_length() ||
162       terminal_size > terminal_rules->max_terminal_length()) {
163     return nullptr;
164   }
165   int terminal_index;
166   if (const char* terminal_match = FindTerminal(
167           input_iterator, rules_set->terminals()->data(),
168           terminal_rules->terminal_offsets()->data(),
169           terminal_rules->terminal_offsets()->size(), &terminal_index)) {
170     *terminal = StringPiece(terminal_match, terminal->length());
171     return rules_set->lhs_set()->Get(
172         terminal_rules->lhs_set_index()->Get(terminal_index));
173   }
174   return nullptr;
175 }
176 
177 // Finds unary rules matches.
FindUnaryRulesMatches(const RulesSet * rules_set,const RulesSet_::Rules * rules,const Nonterm nonterminal)178 const RulesSet_::LhsSet* FindUnaryRulesMatches(const RulesSet* rules_set,
179                                                const RulesSet_::Rules* rules,
180                                                const Nonterm nonterminal) {
181   if (!rules->unary_rules()) {
182     return nullptr;
183   }
184   if (const RulesSet_::Rules_::UnaryRulesEntry* entry =
185           rules->unary_rules()->LookupByKey(nonterminal)) {
186     return rules_set->lhs_set()->Get(entry->value());
187   }
188   return nullptr;
189 }
190 
191 // Finds binary rules matches.
FindBinaryRulesMatches(const RulesSet * rules_set,const RulesSet_::Rules * rules,const TwoNonterms nonterminals)192 const RulesSet_::LhsSet* FindBinaryRulesMatches(
193     const RulesSet* rules_set, const RulesSet_::Rules* rules,
194     const TwoNonterms nonterminals) {
195   if (!rules->binary_rules()) {
196     return nullptr;
197   }
198 
199   // Lookup in rules hash table.
200   const uint32 bucket_index =
201       BinaryRuleHasher()(nonterminals) % rules->binary_rules()->size();
202 
203   // Get hash table bucket.
204   if (const RulesSet_::Rules_::BinaryRuleTableBucket* bucket =
205           rules->binary_rules()->Get(bucket_index)) {
206     if (bucket->rules() == nullptr) {
207       return nullptr;
208     }
209 
210     // Check all entries in the chain.
211     for (const RulesSet_::Rules_::BinaryRule* rule : *bucket->rules()) {
212       if (rule->rhs_first() == nonterminals.first &&
213           rule->rhs_second() == nonterminals.second) {
214         return rules_set->lhs_set()->Get(rule->lhs_set_index());
215       }
216     }
217   }
218 
219   return nullptr;
220 }
221 
GetLhs(const RulesSet * rules_set,const int lhs_entry,Nonterm * nonterminal,CallbackId * callback,int64 * param,int8 * max_whitespace_gap)222 inline void GetLhs(const RulesSet* rules_set, const int lhs_entry,
223                    Nonterm* nonterminal, CallbackId* callback, int64* param,
224                    int8* max_whitespace_gap) {
225   if (lhs_entry > 0) {
226     // Direct encoding of the nonterminal.
227     *nonterminal = lhs_entry;
228     *callback = kNoCallback;
229     *param = 0;
230     *max_whitespace_gap = -1;
231   } else {
232     const RulesSet_::Lhs* lhs = rules_set->lhs()->Get(-lhs_entry);
233     *nonterminal = lhs->nonterminal();
234     *callback = lhs->callback_id();
235     *param = lhs->callback_param();
236     *max_whitespace_gap = lhs->max_whitespace_gap();
237   }
238 }
239 
240 }  // namespace
241 
Finish()242 void Matcher::Finish() {
243   // Check any pending items.
244   ProcessPendingExclusionMatches();
245 }
246 
QueueForProcessing(ParseTree * item)247 void Matcher::QueueForProcessing(ParseTree* item) {
248   // Push element to the front.
249   item->next = pending_items_;
250   pending_items_ = item;
251 }
252 
QueueForPostCheck(ExclusionNode * item)253 void Matcher::QueueForPostCheck(ExclusionNode* item) {
254   // Push element to the front.
255   item->next = pending_exclusion_items_;
256   pending_exclusion_items_ = item;
257 }
258 
AddTerminal(const CodepointSpan codepoint_span,const int match_offset,StringPiece terminal)259 void Matcher::AddTerminal(const CodepointSpan codepoint_span,
260                           const int match_offset, StringPiece terminal) {
261   TC3_CHECK_GE(codepoint_span.second, last_end_);
262 
263   // Finish any pending post-checks.
264   if (codepoint_span.second > last_end_) {
265     ProcessPendingExclusionMatches();
266   }
267 
268   last_end_ = codepoint_span.second;
269   for (const RulesSet_::Rules* shard : rules_shards_) {
270     // Try case-sensitive matches.
271     if (const RulesSet_::LhsSet* lhs_set =
272             FindTerminalMatches(ByteIterator(terminal), rules_,
273                                 shard->terminal_rules(), &terminal)) {
274       // `terminal` points now into the rules string pool, providing a
275       // stable reference.
276       ExecuteLhsSet(
277           codepoint_span, match_offset,
278           /*whitespace_gap=*/(codepoint_span.first - match_offset),
279           [terminal](ParseTree* parse_tree) {
280             parse_tree->terminal = terminal.data();
281             parse_tree->rhs2 = nullptr;
282           },
283           lhs_set);
284     }
285 
286     // Try case-insensitive matches.
287     if (const RulesSet_::LhsSet* lhs_set = FindTerminalMatches(
288             LowercasingByteIterator(&unilib_, terminal), rules_,
289             shard->lowercase_terminal_rules(), &terminal)) {
290       // `terminal` points now into the rules string pool, providing a
291       // stable reference.
292       ExecuteLhsSet(
293           codepoint_span, match_offset,
294           /*whitespace_gap=*/(codepoint_span.first - match_offset),
295           [terminal](ParseTree* parse_tree) {
296             parse_tree->terminal = terminal.data();
297             parse_tree->rhs2 = nullptr;
298           },
299           lhs_set);
300     }
301   }
302   ProcessPendingSet();
303 }
304 
AddParseTree(ParseTree * parse_tree)305 void Matcher::AddParseTree(ParseTree* parse_tree) {
306   TC3_CHECK_GE(parse_tree->codepoint_span.second, last_end_);
307 
308   // Finish any pending post-checks.
309   if (parse_tree->codepoint_span.second > last_end_) {
310     ProcessPendingExclusionMatches();
311   }
312 
313   last_end_ = parse_tree->codepoint_span.second;
314   QueueForProcessing(parse_tree);
315   ProcessPendingSet();
316 }
317 
ExecuteLhsSet(const CodepointSpan codepoint_span,const int match_offset_bytes,const int whitespace_gap,const std::function<void (ParseTree *)> & initializer_fn,const RulesSet_::LhsSet * lhs_set)318 void Matcher::ExecuteLhsSet(
319     const CodepointSpan codepoint_span, const int match_offset_bytes,
320     const int whitespace_gap,
321     const std::function<void(ParseTree*)>& initializer_fn,
322     const RulesSet_::LhsSet* lhs_set) {
323   TC3_CHECK(lhs_set);
324   ParseTree* parse_tree = nullptr;
325   Nonterm prev_lhs = kUnassignedNonterm;
326   for (const int32 lhs_entry : *lhs_set->lhs()) {
327     Nonterm lhs;
328     CallbackId callback_id;
329     int64 callback_param;
330     int8 max_whitespace_gap;
331     GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param,
332            &max_whitespace_gap);
333 
334     // Check that the allowed whitespace gap limit is followed.
335     if (max_whitespace_gap >= 0 && whitespace_gap > max_whitespace_gap) {
336       continue;
337     }
338 
339     // Handle callbacks.
340     switch (static_cast<DefaultCallback>(callback_id)) {
341       case DefaultCallback::kAssertion: {
342         AssertionNode* assertion_node = arena_->AllocAndInit<AssertionNode>(
343             lhs, codepoint_span, match_offset_bytes,
344             /*negative=*/(callback_param != 0));
345         initializer_fn(assertion_node);
346         QueueForProcessing(assertion_node);
347         continue;
348       }
349       case DefaultCallback::kMapping: {
350         MappingNode* mapping_node = arena_->AllocAndInit<MappingNode>(
351             lhs, codepoint_span, match_offset_bytes, /*id=*/callback_param);
352         initializer_fn(mapping_node);
353         QueueForProcessing(mapping_node);
354         continue;
355       }
356       case DefaultCallback::kExclusion: {
357         // We can only check the exclusion once all matches up to this position
358         // have been processed. Schedule and post check later.
359         ExclusionNode* exclusion_node = arena_->AllocAndInit<ExclusionNode>(
360             lhs, codepoint_span, match_offset_bytes,
361             /*exclusion_nonterm=*/callback_param);
362         initializer_fn(exclusion_node);
363         QueueForPostCheck(exclusion_node);
364         continue;
365       }
366       case DefaultCallback::kSemanticExpression: {
367         SemanticExpressionNode* expression_node =
368             arena_->AllocAndInit<SemanticExpressionNode>(
369                 lhs, codepoint_span, match_offset_bytes,
370                 /*expression=*/
371                 rules_->semantic_expression()->Get(callback_param));
372         initializer_fn(expression_node);
373         QueueForProcessing(expression_node);
374         continue;
375       }
376       default:
377         break;
378     }
379 
380     if (prev_lhs != lhs) {
381       prev_lhs = lhs;
382       parse_tree = arena_->AllocAndInit<ParseTree>(
383           lhs, codepoint_span, match_offset_bytes, ParseTree::Type::kDefault);
384       initializer_fn(parse_tree);
385       QueueForProcessing(parse_tree);
386     }
387 
388     if (static_cast<DefaultCallback>(callback_id) ==
389         DefaultCallback::kRootRule) {
390       chart_.AddDerivation(Derivation{parse_tree, /*rule_id=*/callback_param});
391     }
392   }
393 }
394 
ProcessPendingSet()395 void Matcher::ProcessPendingSet() {
396   while (pending_items_) {
397     // Process.
398     ParseTree* item = pending_items_;
399     pending_items_ = pending_items_->next;
400 
401     // Add it to the chart.
402     chart_.Add(item);
403 
404     // Check unary rules that trigger.
405     for (const RulesSet_::Rules* shard : rules_shards_) {
406       if (const RulesSet_::LhsSet* lhs_set =
407               FindUnaryRulesMatches(rules_, shard, item->lhs)) {
408         ExecuteLhsSet(
409             item->codepoint_span, item->match_offset,
410             /*whitespace_gap=*/
411             (item->codepoint_span.first - item->match_offset),
412             [item](ParseTree* parse_tree) {
413               parse_tree->rhs1 = nullptr;
414               parse_tree->rhs2 = item;
415             },
416             lhs_set);
417       }
418     }
419 
420     // Check binary rules that trigger.
421     // Lookup by begin.
422     for (Chart<>::Iterator it = chart_.MatchesEndingAt(item->match_offset);
423          !it.Done(); it.Next()) {
424       const ParseTree* prev = it.Item();
425       for (const RulesSet_::Rules* shard : rules_shards_) {
426         if (const RulesSet_::LhsSet* lhs_set =
427                 FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) {
428           ExecuteLhsSet(
429               /*codepoint_span=*/
430               {prev->codepoint_span.first, item->codepoint_span.second},
431               prev->match_offset,
432               /*whitespace_gap=*/
433               (item->codepoint_span.first -
434                item->match_offset),  // Whitespace gap is the gap
435                                      // between the two parts.
436               [prev, item](ParseTree* parse_tree) {
437                 parse_tree->rhs1 = prev;
438                 parse_tree->rhs2 = item;
439               },
440               lhs_set);
441         }
442       }
443     }
444   }
445 }
446 
ProcessPendingExclusionMatches()447 void Matcher::ProcessPendingExclusionMatches() {
448   while (pending_exclusion_items_) {
449     ExclusionNode* item = pending_exclusion_items_;
450     pending_exclusion_items_ = static_cast<ExclusionNode*>(item->next);
451 
452     // Check that the exclusion condition is fulfilled.
453     if (!chart_.HasMatch(item->exclusion_nonterm, item->codepoint_span)) {
454       AddParseTree(item);
455     }
456   }
457 }
458 
459 }  // namespace libtextclassifier3::grammar
460