1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/utils/ir.h"
18 
19 #include "utils/i18n/locale.h"
20 #include "utils/strings/append.h"
21 #include "utils/strings/stringpiece.h"
22 #include "utils/zlib/zlib.h"
23 
24 namespace libtextclassifier3::grammar {
25 namespace {
26 
// Upper bound on the number of buckets in a serialized binary rules hash
// table; tables with fewer rules get one bucket per rule.
constexpr size_t kMaxHashTableSize{100};
28 
// Orders a container of pointer-like entries (anything supporting `->key`,
// e.g. std::unique_ptr to an entry object) ascending by key, so that it can
// later be looked up with binary search.
template <typename T>
void SortForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& left, const auto& right) {
    return left->key < right->key;
  };
  auto& container = *entries;
  std::sort(container.begin(), container.end(), by_key);
}
34 
// Orders a container of value entries exposing a `key()` accessor ascending by
// key, so that it can later be looked up with binary search.
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& left, const auto& right) {
    return left.key() < right.key();
  };
  auto& container = *entries;
  std::sort(container.begin(), container.end(), by_key);
}
40 
IsSameLhs(const Ir::Lhs & lhs,const RulesSet_::Lhs & other)41 bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
42   return (lhs.nonterminal == other.nonterminal() &&
43           lhs.callback.id == other.callback_id() &&
44           lhs.callback.param == other.callback_param() &&
45           lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap());
46 }
47 
IsSameLhsEntry(const Ir::Lhs & lhs,const int32 lhs_entry,const std::vector<RulesSet_::Lhs> & candidates)48 bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
49                     const std::vector<RulesSet_::Lhs>& candidates) {
50   // Simple case: direct encoding of the nonterminal.
51   if (lhs_entry > 0) {
52     return (lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
53             lhs.preconditions.max_whitespace_gap == -1);
54   }
55 
56   // Entry is index into callback lookup.
57   return IsSameLhs(lhs, candidates[-lhs_entry]);
58 }
59 
IsSameLhsSet(const Ir::LhsSet & lhs_set,const RulesSet_::LhsSetT & candidate,const std::vector<RulesSet_::Lhs> & candidates)60 bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
61                   const RulesSet_::LhsSetT& candidate,
62                   const std::vector<RulesSet_::Lhs>& candidates) {
63   if (lhs_set.size() != candidate.lhs.size()) {
64     return false;
65   }
66 
67   for (int i = 0; i < lhs_set.size(); i++) {
68     // Check that entries are the same.
69     if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
70       return false;
71     }
72   }
73 
74   return true;
75 }
76 
// Returns a copy of `lhs_set` sorted by (nonterminal, callback id, callback
// parameter, max whitespace gap). This is intended to give a canonical element
// order, so that equal sets compare and serialize the same way.
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
  Ir::LhsSet sorted_lhs = lhs_set;
  std::sort(sorted_lhs.begin(), sorted_lhs.end(),
            [](const Ir::Lhs& a, const Ir::Lhs& b) {
              return std::tie(a.nonterminal, a.callback.id, a.callback.param,
                              a.preconditions.max_whitespace_gap) <
                     std::tie(b.nonterminal, b.callback.id, b.callback.param,
                              b.preconditions.max_whitespace_gap);
            });
  // Return the sorted copy; returning `lhs_set` would silently drop the sort.
  return sorted_lhs;
}
88 
89 // Adds a new lhs match set to the output.
90 // Reuses the same set, if it was previously observed.
AddLhsSet(const Ir::LhsSet & lhs_set,RulesSetT * rules_set)91 int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
92   Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
93   // Check whether we can reuse an entry.
94   const int output_size = rules_set->lhs_set.size();
95   for (int i = 0; i < output_size; i++) {
96     if (IsSameLhsSet(lhs_set, *rules_set->lhs_set[i], rules_set->lhs)) {
97       return i;
98     }
99   }
100 
101   // Add new entry.
102   rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
103   RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
104   for (const Ir::Lhs& lhs : lhs_set) {
105     // Simple case: No callback and no special requirements, we directly encode
106     // the nonterminal.
107     if (lhs.callback.id == kNoCallback &&
108         lhs.preconditions.max_whitespace_gap < 0) {
109       serialized_lhs_set->lhs.push_back(lhs.nonterminal);
110     } else {
111       // Check whether we can reuse a callback entry.
112       const int lhs_size = rules_set->lhs.size();
113       bool found_entry = false;
114       for (int i = 0; i < lhs_size; i++) {
115         if (IsSameLhs(lhs, rules_set->lhs[i])) {
116           found_entry = true;
117           serialized_lhs_set->lhs.push_back(-i);
118           break;
119         }
120       }
121 
122       // We could reuse an existing entry.
123       if (found_entry) {
124         continue;
125       }
126 
127       // Add a new one.
128       rules_set->lhs.push_back(
129           RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
130                          lhs.preconditions.max_whitespace_gap));
131       serialized_lhs_set->lhs.push_back(-lhs_size);
132     }
133   }
134   return output_size;
135 }
136 
137 // Serializes a unary rules table.
SerializeUnaryRulesShard(const std::unordered_map<Nonterm,Ir::LhsSet> & unary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)138 void SerializeUnaryRulesShard(
139     const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
140     RulesSetT* rules_set, RulesSet_::RulesT* rules) {
141   for (const auto& it : unary_rules) {
142     rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
143         it.first, AddLhsSet(it.second, rules_set)));
144   }
145   SortStructsForBinarySearchLookup(&rules->unary_rules);
146 }
147 
148 // // Serializes a binary rules table.
SerializeBinaryRulesShard(const std::unordered_map<TwoNonterms,Ir::LhsSet,BinaryRuleHasher> & binary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)149 void SerializeBinaryRulesShard(
150     const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
151         binary_rules,
152     RulesSetT* rules_set, RulesSet_::RulesT* rules) {
153   const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
154   for (int i = 0; i < num_buckets; i++) {
155     rules->binary_rules.emplace_back(
156         new RulesSet_::Rules_::BinaryRuleTableBucketT());
157   }
158 
159   // Serialize the table.
160   BinaryRuleHasher hash;
161   for (const auto& it : binary_rules) {
162     const TwoNonterms key = it.first;
163     uint32 bucket_index = hash(key) % num_buckets;
164 
165     // Add entry to bucket chain list.
166     rules->binary_rules[bucket_index]->rules.push_back(
167         RulesSet_::Rules_::BinaryRule(key.first, key.second,
168                                       AddLhsSet(it.second, rules_set)));
169   }
170 }
171 
172 }  // namespace
173 
// Adds `lhs` to `lhs_set`, reusing a compatible existing entry where possible,
// and returns the nonterminal that represents `lhs`.
//
// A candidate entry is reused (its nonterminal is returned) if it equals `lhs`
// exactly, or if all of the following hold:
//  - neither the candidate's nor `lhs`'s nonterminal is in `nonshareable_`,
//  - the preconditions are equal,
//  - `lhs.nonterminal` is unassigned or equal to the candidate's nonterminal,
//  - the callbacks are equal, or at least one of them is kNoCallback. If only
//    the candidate lacks a callback, it takes over `lhs.callback`.
// If no candidate can be reused, a new entry is appended. Its nonterminal is
// DefineNonterminal() applied to one of two values:
//  - the nonterminal of the last candidate that met the first three
//    conditions but had a different callback, or
//  - `lhs.nonterminal`, if there was no such candidate.
Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
  const int lhs_set_size = lhs_set->size();
  // Nonterminal to use for a new entry if no existing entry can be reused.
  Nonterm shareable_nonterm = lhs.nonterminal;
  for (int i = 0; i < lhs_set_size; i++) {
    // Non-const: the candidate's callback may be filled in below.
    Lhs* candidate = &lhs_set->at(i);

    // Exact match, just reuse rule.
    if (lhs == *candidate) {
      return candidate->nonterminal;
    }

    // Cannot reuse unshareable ids.
    if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
        nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
      continue;
    }

    // Cannot reuse id if the preconditions are different.
    if (!(lhs.preconditions == candidate->preconditions)) {
      continue;
    }

    // If the nonterminal is already defined, it must match for sharing.
    if (lhs.nonterminal != kUnassignedNonterm &&
        lhs.nonterminal != candidate->nonterminal) {
      continue;
    }

    // Check whether the callbacks match.
    if (lhs.callback == candidate->callback) {
      return candidate->nonterminal;
    }

    // We can reuse if one of the output callbacks is not used.
    if (lhs.callback.id == kNoCallback) {
      return candidate->nonterminal;
    } else if (candidate->callback.id == kNoCallback) {
      // Old entry has no output callback, which is redundant now.
      candidate->callback = lhs.callback;
      return candidate->nonterminal;
    }

    // We can share the nonterminal, but we need to
    // add a new output callback. Defer this as we might find a shareable
    // nonterminal first.
    shareable_nonterm = candidate->nonterminal;
  }

  // We didn't find a redundant entry, so create a new one.
  // NOTE(review): presumably DefineNonterminal allocates a fresh id when
  // shareable_nonterm is kUnassignedNonterm; confirm in its definition.
  shareable_nonterm = DefineNonterminal(shareable_nonterm);
  lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
  return shareable_nonterm;
}
227 
Add(const Lhs & lhs,const std::string & terminal,const bool case_sensitive,const int shard)228 Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
229                 const bool case_sensitive, const int shard) {
230   TC3_CHECK_LT(shard, shards_.size());
231   if (case_sensitive) {
232     return AddRule(lhs, terminal, &shards_[shard].terminal_rules);
233   } else {
234     return AddRule(lhs, terminal, &shards_[shard].lowercase_terminal_rules);
235   }
236 }
237 
Add(const Lhs & lhs,const std::vector<Nonterm> & rhs,const int shard)238 Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
239                 const int shard) {
240   // Add a new unary rule.
241   if (rhs.size() == 1) {
242     return Add(lhs, rhs.front(), shard);
243   }
244 
245   // Add a chain of (rhs.size() - 1) binary rules.
246   Nonterm prev = rhs.front();
247   for (int i = 1; i < rhs.size() - 1; i++) {
248     prev = Add(kUnassignedNonterm, prev, rhs[i], shard);
249   }
250   return Add(lhs, prev, rhs.back(), shard);
251 }
252 
AddRegex(Nonterm lhs,const std::string & regex_pattern)253 Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
254   lhs = DefineNonterminal(lhs);
255   regex_rules_.emplace_back(regex_pattern, lhs);
256   return lhs;
257 }
258 
// Associates the annotation name `annotation` with nonterminal `lhs`.
// Serialize() later writes each recorded pair as a
// `nonterminals->annotation_nt` entry (key = annotation, value = nonterminal).
// NOTE(review): duplicates are not checked here; adding the same annotation
// twice records two pairs.
void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
  annotations_.emplace_back(annotation, lhs);
}
262 
263 // Serializes the terminal rules table.
SerializeTerminalRules(RulesSetT * rules_set,std::vector<std::unique_ptr<RulesSet_::RulesT>> * rules_shards) const264 void Ir::SerializeTerminalRules(
265     RulesSetT* rules_set,
266     std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
267   // Use common pool for all terminals.
268   struct TerminalEntry {
269     std::string terminal;
270     int set_index;
271     int index;
272     Ir::LhsSet lhs_set;
273   };
274   std::vector<TerminalEntry> terminal_rules;
275 
276   // Merge all terminals into a common pool.
277   // We want to use one common pool, but still need to track which set they
278   // belong to.
279   std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
280       terminal_rules_sets;
281   std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
282   terminal_rules_sets.reserve(2 * shards_.size());
283   rules_maps.reserve(terminal_rules_sets.size());
284   for (int i = 0; i < shards_.size(); i++) {
285     terminal_rules_sets.push_back(&shards_[i].terminal_rules);
286     terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
287     rules_shards->at(i)->terminal_rules.reset(
288         new RulesSet_::Rules_::TerminalRulesMapT());
289     rules_shards->at(i)->lowercase_terminal_rules.reset(
290         new RulesSet_::Rules_::TerminalRulesMapT());
291     rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
292     rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
293   }
294   for (int i = 0; i < terminal_rules_sets.size(); i++) {
295     for (const auto& it : *terminal_rules_sets[i]) {
296       terminal_rules.push_back(
297           TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
298     }
299   }
300   std::sort(terminal_rules.begin(), terminal_rules.end(),
301             [](const TerminalEntry& a, const TerminalEntry& b) {
302               return a.terminal < b.terminal;
303             });
304 
305   // Index the entries in sorted order.
306   std::vector<int> index(terminal_rules_sets.size(), 0);
307   for (int i = 0; i < terminal_rules.size(); i++) {
308     terminal_rules[i].index = index[terminal_rules[i].set_index]++;
309   }
310 
311   // We store the terminal strings sorted into a buffer and keep offsets into
312   // that buffer. In this way, we don't need extra space for terminals that are
313   // suffixes of others.
314 
315   // Find terminals that are a suffix of others, O(n^2) algorithm.
316   constexpr int kInvalidIndex = -1;
317   std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
318   for (int i = 0; i < terminal_rules.size(); i++) {
319     const StringPiece terminal(terminal_rules[i].terminal);
320 
321     // Check whether the ith terminal is a suffix of another.
322     for (int j = 0; j < terminal_rules.size(); j++) {
323       if (i == j) {
324         continue;
325       }
326       if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
327         // If both terminals are the same keep the first.
328         // This avoids cyclic dependencies.
329         // This can happen if multiple shards use same terminals, such as
330         // punctuation.
331         if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
332           continue;
333         }
334         suffix[i] = j;
335         break;
336       }
337     }
338   }
339 
340   rules_set->terminals = "";
341 
342   for (int i = 0; i < terminal_rules_sets.size(); i++) {
343     rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
344     rules_maps[i]->max_terminal_length = 0;
345     rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
346   }
347 
348   for (int i = 0; i < terminal_rules.size(); i++) {
349     const TerminalEntry& entry = terminal_rules[i];
350 
351     // Update bounds.
352     rules_maps[entry.set_index]->min_terminal_length =
353         std::min(rules_maps[entry.set_index]->min_terminal_length,
354                  static_cast<int>(entry.terminal.size()));
355     rules_maps[entry.set_index]->max_terminal_length =
356         std::max(rules_maps[entry.set_index]->max_terminal_length,
357                  static_cast<int>(entry.terminal.size()));
358 
359     // Only include terminals that are not suffixes of others.
360     if (suffix[i] != kInvalidIndex) {
361       continue;
362     }
363 
364     rules_maps[entry.set_index]->terminal_offsets[entry.index] =
365         rules_set->terminals.length();
366     rules_set->terminals += entry.terminal + '\0';
367   }
368 
369   // Store just an offset into the existing terminal data for the terminals
370   // that are suffixes of others.
371   for (int i = 0; i < terminal_rules.size(); i++) {
372     int canonical_index = i;
373     if (suffix[canonical_index] == kInvalidIndex) {
374       continue;
375     }
376 
377     // Find the overlapping string that was included in the data.
378     while (suffix[canonical_index] != kInvalidIndex) {
379       canonical_index = suffix[canonical_index];
380     }
381 
382     const TerminalEntry& entry = terminal_rules[i];
383     const TerminalEntry& canonical_entry = terminal_rules[canonical_index];
384 
385     // The offset is the offset of the overlapping string and the offset within
386     // that string.
387     rules_maps[entry.set_index]->terminal_offsets[entry.index] =
388         rules_maps[canonical_entry.set_index]
389             ->terminal_offsets[canonical_entry.index] +
390         (canonical_entry.terminal.length() - entry.terminal.length());
391   }
392 
393   for (const TerminalEntry& entry : terminal_rules) {
394     rules_maps[entry.set_index]->lhs_set_index.push_back(
395         AddLhsSet(entry.lhs_set, rules_set));
396   }
397 }
398 
Serialize(const bool include_debug_information,RulesSetT * output) const399 void Ir::Serialize(const bool include_debug_information,
400                    RulesSetT* output) const {
401   // Add information about predefined nonterminal classes.
402   output->nonterminals.reset(new RulesSet_::NonterminalsT);
403   output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
404   output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
405   output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
406   output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
407   output->nonterminals->uppercase_token_nt =
408       GetNonterminalForName(kUppercaseTokenNonterm);
409   output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
410   for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
411     if (const Nonterm n_digits_nt =
412             GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
413       output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
414       output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
415     }
416   }
417   for (const auto& [annotation, annotation_nt] : annotations_) {
418     output->nonterminals->annotation_nt.emplace_back(
419         new RulesSet_::Nonterminals_::AnnotationNtEntryT);
420     output->nonterminals->annotation_nt.back()->key = annotation;
421     output->nonterminals->annotation_nt.back()->value = annotation_nt;
422   }
423   SortForBinarySearchLookup(&output->nonterminals->annotation_nt);
424 
425   if (include_debug_information) {
426     output->debug_information.reset(new RulesSet_::DebugInformationT);
427     // Keep original non-terminal names.
428     for (const auto& it : nonterminal_names_) {
429       output->debug_information->nonterminal_names.emplace_back(
430           new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
431       output->debug_information->nonterminal_names.back()->key = it.first;
432       output->debug_information->nonterminal_names.back()->value = it.second;
433     }
434     SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
435   }
436 
437   // Add regex rules.
438   std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
439   for (auto [pattern, lhs] : regex_rules_) {
440     output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
441     output->regex_annotator.back()->compressed_pattern.reset(
442         new CompressedBufferT);
443     compressor->Compress(
444         pattern, output->regex_annotator.back()->compressed_pattern.get());
445     output->regex_annotator.back()->nonterminal = lhs;
446   }
447 
448   // Serialize the unary and binary rules.
449   for (int i = 0; i < shards_.size(); i++) {
450     output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
451     RulesSet_::RulesT* rules = output->rules.back().get();
452     for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
453       if (shard_locale.IsValid()) {
454         // Check if the language is set to all i.e. '*' which is a special, to
455         // make it consistent with device side parser here instead of filling
456         // the all locale leave the language tag list empty
457         rules->locale.emplace_back(
458             std::make_unique<libtextclassifier3::LanguageTagT>());
459         libtextclassifier3::LanguageTagT* language_tag =
460             rules->locale.back().get();
461         language_tag->language = shard_locale.Language();
462         language_tag->region = shard_locale.Region();
463         language_tag->script = shard_locale.Script();
464       }
465     }
466 
467     // Serialize the unary rules.
468     SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
469     // Serialize the binary rules.
470     SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
471   }
472   // Serialize the terminal rules.
473   // We keep the rules separate by shard but merge the actual terminals into
474   // one shared string pool to most effectively exploit reuse.
475   SerializeTerminalRules(output, &output->rules);
476 }
477 
SerializeAsFlatbuffer(const bool include_debug_information) const478 std::string Ir::SerializeAsFlatbuffer(
479     const bool include_debug_information) const {
480   RulesSetT output;
481   Serialize(include_debug_information, &output);
482   flatbuffers::FlatBufferBuilder builder;
483   builder.Finish(RulesSet::Pack(builder, &output));
484   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
485                      builder.GetSize());
486 }
487 
488 }  // namespace libtextclassifier3::grammar
489