1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/utils/rules.h"
18 
19 #include <set>
20 
21 #include "utils/grammar/utils/ir.h"
22 #include "utils/strings/append.h"
23 #include "utils/strings/stringpiece.h"
24 
25 namespace libtextclassifier3::grammar {
26 namespace {
27 
28 // Returns whether a nonterminal is a pre-defined one.
IsPredefinedNonterminal(const std::string & nonterminal_name)29 bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
30   if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
31       nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
32       nonterminal_name == kWordBreakNonterm) {
33     return true;
34   }
35   for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
36     if (nonterminal_name == strings::StringPrintf(kNDigitsNonterm, digits)) {
37       return true;
38     }
39   }
40   return false;
41 }
42 
43 // Gets an assigned Nonterm for a nonterminal or kUnassignedNonterm if not yet
44 // assigned.
GetAssignedIdForNonterminal(const int nonterminal,const std::unordered_map<int,Nonterm> & assignment)45 Nonterm GetAssignedIdForNonterminal(
46     const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
47   const auto it = assignment.find(nonterminal);
48   if (it == assignment.end()) {
49     return kUnassignedNonterm;
50   }
51   return it->second;
52 }
53 
54 // Checks whether all the nonterminals in the rhs of a rule have already been
55 // assigned Nonterm values.
IsRhsAssigned(const Rules::Rule & rule,const std::unordered_map<int,Nonterm> & nonterminals)56 bool IsRhsAssigned(const Rules::Rule& rule,
57                    const std::unordered_map<int, Nonterm>& nonterminals) {
58   for (const Rules::RhsElement& element : rule.rhs) {
59     // Terminals are always considered assigned, check only for non-terminals.
60     if (element.is_terminal) {
61       continue;
62     }
63     if (GetAssignedIdForNonterminal(element.nonterminal, nonterminals) ==
64         kUnassignedNonterm) {
65       return false;
66     }
67   }
68 
69   // Check that all parts of an exclusion are defined.
70   if (rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
71     if (GetAssignedIdForNonterminal(rule.callback_param, nonterminals) ==
72         kUnassignedNonterm) {
73       return false;
74     }
75   }
76 
77   return true;
78 }
79 
// Lowers a single high-level rule down into the intermediate representation.
//
// `lhs_index` is the index of the rule's lhs in Rules::nonterminals_, and
// `nonterminals` maps such indices to the Nonterm ids assigned so far. Every
// nonterminal referenced by the rhs must already be assigned, and so must the
// excluded nonterminal of an exclusion rule. This is enforced with
// TC3_CHECK_NE. The Nonterm returned by `ir->Add` is written back as the id of
// `lhs_index` in `*nonterminals`.
void LowerRule(const int lhs_index, const Rules::Rule& rule,
               std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
  const CallbackId callback = rule.callback;
  int64 callback_param = rule.callback_param;

  // Resolve id of excluded nonterminal in exclusion rules.
  // For exclusions `callback_param` holds the excluded nonterminal's index;
  // the IR is given its assigned Nonterm id instead.
  if (callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
    callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
    TC3_CHECK_NE(callback_param, kUnassignedNonterm);
  }

  // Special case for terminal rules.
  // A rhs of exactly one terminal is added as a single terminal rule that
  // carries the callback and preconditions directly.
  if (rule.rhs.size() == 1 && rule.rhs.front().is_terminal) {
    (*nonterminals)[lhs_index] =
        ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
                        /*callback=*/{callback, callback_param},
                        /*preconditions=*/{rule.max_whitespace_gap}},
                rule.rhs.front().terminal, rule.case_sensitive, rule.shard);
    return;
  }

  // Nonterminal rules.
  // Each terminal in the rhs is first added as its own terminal rule with an
  // unassigned lhs, and the Nonterm returned by `ir->Add` stands in for it.
  // Nonterminals are replaced by their already assigned ids.
  std::vector<Nonterm> rhs_nonterms;
  for (const Rules::RhsElement& element : rule.rhs) {
    if (element.is_terminal) {
      rhs_nonterms.push_back(ir->Add(Ir::Lhs{kUnassignedNonterm},
                                     element.terminal, rule.case_sensitive,
                                     rule.shard));
    } else {
      Nonterm nonterminal_id =
          GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
      TC3_CHECK_NE(nonterminal_id, kUnassignedNonterm);
      rhs_nonterms.push_back(nonterminal_id);
    }
  }
  // NOTE(review): the id returned here presumably can differ from the
  // pre-assigned one when the IR reuses an equivalent nonterminal -- confirm
  // in Ir::Add.
  (*nonterminals)[lhs_index] =
      ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
                      /*callback=*/{callback, callback_param},
                      /*preconditions=*/{rule.max_whitespace_gap}},
              rhs_nonterms, rule.shard);
}
122 // Check whether this component is a non-terminal.
IsNonterminal(StringPiece rhs_component)123 bool IsNonterminal(StringPiece rhs_component) {
124   return rhs_component[0] == '<' &&
125          rhs_component[rhs_component.size() - 1] == '>';
126 }
127 
// Sanity check for common typos in a terminal. A terminal must not contain
// '<' or '>' (e.g. an unbalanced nonterminal reference such as `<foo`). It
// must not contain '?' either, since the optional marker is only recognized as
// the last character of a component. Fails a TC3_CHECK with a descriptive
// message otherwise.
void ValidateTerminal(StringPiece rhs_component) {
  TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains a question mark.";
}
137 
138 }  // namespace
139 
AddNonterminal(const std::string & nonterminal_name)140 int Rules::AddNonterminal(const std::string& nonterminal_name) {
141   std::string key = nonterminal_name;
142   auto alias_it = nonterminal_alias_.find(key);
143   if (alias_it != nonterminal_alias_.end()) {
144     key = alias_it->second;
145   }
146   auto it = nonterminal_names_.find(key);
147   if (it != nonterminal_names_.end()) {
148     return it->second;
149   }
150   const int index = nonterminals_.size();
151   nonterminals_.push_back(NontermInfo{key});
152   nonterminal_names_.insert(it, {key, index});
153   return index;
154 }
155 
AddNewNonterminal()156 int Rules::AddNewNonterminal() {
157   const int index = nonterminals_.size();
158   nonterminals_.push_back(NontermInfo{});
159   return index;
160 }
161 
AddAlias(const std::string & nonterminal_name,const std::string & alias)162 void Rules::AddAlias(const std::string& nonterminal_name,
163                      const std::string& alias) {
164 #ifndef TC3_USE_CXX14
165   TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
166                    .first->second,
167                nonterminal_name)
168       << "Cannot redefine alias: " << alias;
169 #else
170   nonterminal_alias_[alias] = nonterminal_name;
171   TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
172       << "Cannot redefine alias: " << alias;
173 #endif
174 }
175 
176 // Defines a nonterminal for an externally provided annotation.
AddAnnotation(const std::string & annotation_name)177 int Rules::AddAnnotation(const std::string& annotation_name) {
178   auto [it, inserted] =
179       annotation_nonterminals_.insert({annotation_name, nonterminals_.size()});
180   if (inserted) {
181     nonterminals_.push_back(NontermInfo{});
182   }
183   return it->second;
184 }
185 
BindAnnotation(const std::string & nonterminal_name,const std::string & annotation_name)186 void Rules::BindAnnotation(const std::string& nonterminal_name,
187                            const std::string& annotation_name) {
188   auto [_, inserted] = annotation_nonterminals_.insert(
189       {annotation_name, AddNonterminal(nonterminal_name)});
190   TC3_CHECK(inserted);
191 }
192 
IsNonterminalOfName(const RhsElement & element,const std::string & nonterminal) const193 bool Rules::IsNonterminalOfName(const RhsElement& element,
194                                 const std::string& nonterminal) const {
195   if (element.is_terminal) {
196     return false;
197   }
198   return (nonterminals_[element.nonterminal].name == nonterminal);
199 }
200 
201 // Note: For k optional components this creates 2^k rules, but it would be
202 // possible to be smarter about this and only use 2k rules instead.
203 // However that might be slower as it requires an extra rule firing at match
204 // time for every omitted optional element.
ExpandOptionals(const int lhs,const std::vector<RhsElement> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard,std::vector<int>::const_iterator optional_element_indices,std::vector<int>::const_iterator optional_element_indices_end,std::vector<bool> * omit_these)205 void Rules::ExpandOptionals(
206     const int lhs, const std::vector<RhsElement>& rhs,
207     const CallbackId callback, const int64 callback_param,
208     const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
209     std::vector<int>::const_iterator optional_element_indices,
210     std::vector<int>::const_iterator optional_element_indices_end,
211     std::vector<bool>* omit_these) {
212   if (optional_element_indices == optional_element_indices_end) {
213     // Nothing is optional, so just generate a rule.
214     Rule r;
215     for (uint32 i = 0; i < rhs.size(); i++) {
216       if (!omit_these->at(i)) {
217         r.rhs.push_back(rhs[i]);
218       }
219     }
220     r.callback = callback;
221     r.callback_param = callback_param;
222     r.max_whitespace_gap = max_whitespace_gap;
223     r.case_sensitive = case_sensitive;
224     r.shard = shard;
225     nonterminals_[lhs].rules.push_back(rules_.size());
226     rules_.push_back(r);
227     return;
228   }
229 
230   const int next_optional_part = *optional_element_indices;
231   ++optional_element_indices;
232 
233   // Recursive call 1: The optional part is omitted.
234   (*omit_these)[next_optional_part] = true;
235   ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
236                   case_sensitive, shard, optional_element_indices,
237                   optional_element_indices_end, omit_these);
238 
239   // Recursive call 2: The optional part is required.
240   (*omit_these)[next_optional_part] = false;
241   ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
242                   case_sensitive, shard, optional_element_indices,
243                   optional_element_indices_end, omit_these);
244 }
245 
ResolveAnchors(const std::vector<RhsElement> & rhs) const246 std::vector<Rules::RhsElement> Rules::ResolveAnchors(
247     const std::vector<RhsElement>& rhs) const {
248   if (rhs.size() <= 2) {
249     return rhs;
250   }
251   auto begin = rhs.begin();
252   auto end = rhs.end();
253   if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
254       IsNonterminalOfName(rhs[1], kFiller)) {
255     // Skip start anchor and filler.
256     begin += 2;
257   }
258   if (IsNonterminalOfName(rhs.back(), kEndNonterm) &&
259       IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
260     // Skip filler and end anchor.
261     end -= 2;
262   }
263   return std::vector<Rules::RhsElement>(begin, end);
264 }
265 
ResolveFillers(const std::vector<RhsElement> & rhs,int shard)266 std::vector<Rules::RhsElement> Rules::ResolveFillers(
267     const std::vector<RhsElement>& rhs, int shard) {
268   std::vector<RhsElement> result;
269   for (int i = 0; i < rhs.size();) {
270     if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
271         rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
272       result.push_back(rhs[i]);
273       i++;
274       continue;
275     }
276 
277     // We have the case:
278     // <a> <filler>
279     // rewrite as:
280     // <a_with_tokens> ::= <a>
281     // <a_with_tokens> ::= <a_with_tokens> <token>
282     const int with_tokens_nonterminal = AddNewNonterminal();
283     const RhsElement token(AddNonterminal(kTokenNonterm),
284                            /*is_optional=*/false);
285     if (rhs[i + 1].is_optional) {
286       // <a_with_tokens> ::= <a>
287       Add(with_tokens_nonterminal, {rhs[i]},
288           /*callback=*/kNoCallback,
289           /*callback_param=*/0,
290           /*max_whitespace_gap=*/-1,
291           /*case_sensitive=*/false, shard);
292     } else {
293       // <a_with_tokens> ::= <a> <token>
294       Add(with_tokens_nonterminal, {rhs[i], token},
295           /*callback=*/kNoCallback,
296           /*callback_param=*/0,
297           /*max_whitespace_gap=*/-1,
298           /*case_sensitive=*/false, shard);
299     }
300     // <a_with_tokens> ::= <a_with_tokens> <token>
301     const RhsElement with_tokens(with_tokens_nonterminal,
302                                  /*is_optional=*/false);
303     Add(with_tokens_nonterminal, {with_tokens, token},
304         /*callback=*/kNoCallback,
305         /*callback_param=*/0,
306         /*max_whitespace_gap=*/-1,
307         /*case_sensitive=*/false, shard);
308     result.push_back(with_tokens);
309     i += 2;
310   }
311   return result;
312 }
313 
OptimizeRhs(const std::vector<RhsElement> & rhs,int shard)314 std::vector<Rules::RhsElement> Rules::OptimizeRhs(
315     const std::vector<RhsElement>& rhs, int shard) {
316   return ResolveFillers(ResolveAnchors(rhs), shard);
317 }
318 
Add(const int lhs,const std::vector<RhsElement> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)319 void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
320                 const CallbackId callback, const int64 callback_param,
321                 const int8 max_whitespace_gap, const bool case_sensitive,
322                 const int shard) {
323   // Resolve anchors and fillers.
324   const std::vector optimized_rhs = OptimizeRhs(rhs);
325 
326   std::vector<int> optional_element_indices;
327   TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
328       << "Rhs must contain at least one non-optional element.";
329   for (int i = 0; i < optimized_rhs.size(); i++) {
330     if (optimized_rhs[i].is_optional) {
331       optional_element_indices.push_back(i);
332     }
333   }
334   std::vector<bool> omit_these(optimized_rhs.size(), false);
335   ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
336                   max_whitespace_gap, case_sensitive, shard,
337                   optional_element_indices.begin(),
338                   optional_element_indices.end(), &omit_these);
339 }
340 
Add(const std::string & lhs,const std::vector<std::string> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)341 void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
342                 const CallbackId callback, const int64 callback_param,
343                 const int8 max_whitespace_gap, const bool case_sensitive,
344                 const int shard) {
345   TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
346   TC3_CHECK(!IsPredefinedNonterminal(lhs));
347   std::vector<RhsElement> rhs_elements;
348   rhs_elements.reserve(rhs.size());
349   for (StringPiece rhs_component : rhs) {
350     // Check whether this component is optional.
351     bool is_optional = false;
352     if (rhs_component[rhs_component.size() - 1] == '?') {
353       rhs_component.RemoveSuffix(1);
354       is_optional = true;
355     }
356     // Check whether this component is a non-terminal.
357     if (IsNonterminal(rhs_component)) {
358       rhs_elements.push_back(
359           RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
360     } else {
361       // A terminal.
362       // Sanity check for common typos -- '<' or '>' in a terminal.
363       ValidateTerminal(rhs_component);
364       rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
365     }
366   }
367   Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
368       max_whitespace_gap, case_sensitive, shard);
369 }
370 
AddWithExclusion(const std::string & lhs,const std::vector<std::string> & rhs,const std::string & excluded_nonterminal,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)371 void Rules::AddWithExclusion(const std::string& lhs,
372                              const std::vector<std::string>& rhs,
373                              const std::string& excluded_nonterminal,
374                              const int8 max_whitespace_gap,
375                              const bool case_sensitive, const int shard) {
376   Add(lhs, rhs,
377       /*callback=*/static_cast<CallbackId>(DefaultCallback::kExclusion),
378       /*callback_param=*/AddNonterminal(excluded_nonterminal),
379       max_whitespace_gap, case_sensitive, shard);
380 }
381 
AddAssertion(const std::string & lhs,const std::vector<std::string> & rhs,const bool negative,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)382 void Rules::AddAssertion(const std::string& lhs,
383                          const std::vector<std::string>& rhs,
384                          const bool negative, const int8 max_whitespace_gap,
385                          const bool case_sensitive, const int shard) {
386   Add(lhs, rhs,
387       /*callback=*/static_cast<CallbackId>(DefaultCallback::kAssertion),
388       /*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
389 }
390 
AddValueMapping(const std::string & lhs,const std::vector<std::string> & rhs,const int64 value,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)391 void Rules::AddValueMapping(const std::string& lhs,
392                             const std::vector<std::string>& rhs,
393                             const int64 value, const int8 max_whitespace_gap,
394                             const bool case_sensitive, const int shard) {
395   Add(lhs, rhs,
396       /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
397       /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
398 }
399 
AddValueMapping(const int lhs,const std::vector<RhsElement> & rhs,int64 value,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)400 void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
401                             int64 value, const int8 max_whitespace_gap,
402                             const bool case_sensitive, const int shard) {
403   Add(lhs, rhs,
404       /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
405       /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
406 }
407 
AddRegex(const std::string & lhs,const std::string & regex_pattern)408 void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
409   AddRegex(AddNonterminal(lhs), regex_pattern);
410 }
411 
AddRegex(int lhs,const std::string & regex_pattern)412 void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
413   nonterminals_[lhs].regex_rules.push_back(regex_rules_.size());
414   regex_rules_.push_back(regex_pattern);
415 }
416 
UsesFillers() const417 bool Rules::UsesFillers() const {
418   for (const Rule& rule : rules_) {
419     for (const RhsElement& rhs_element : rule.rhs) {
420       if (IsNonterminalOfName(rhs_element, kFiller)) {
421         return true;
422       }
423     }
424   }
425   return false;
426 }
427 
Finalize(const std::set<std::string> & predefined_nonterminals) const428 Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
429   Ir rules(locale_shard_map_);
430   std::unordered_map<int, Nonterm> nonterminal_ids;
431 
432   // Pending rules to process.
433   std::set<std::pair<int, int>> scheduled_rules;
434 
435   // Define all used predefined nonterminals.
436   for (const auto& it : nonterminal_names_) {
437     if (IsPredefinedNonterminal(it.first) ||
438         predefined_nonterminals.find(it.first) !=
439             predefined_nonterminals.end()) {
440       nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
441     }
442   }
443 
444   // Assign (unmergeable) Nonterm values to any nonterminals that have
445   // multiple rules.
446   for (int i = 0; i < nonterminals_.size(); i++) {
447     const NontermInfo& nonterminal = nonterminals_[i];
448 
449     // Skip predefined nonterminals, they have already been assigned.
450     if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
451       continue;
452     }
453 
454     bool unmergeable =
455         (nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
456          !nonterminal.regex_rules.empty());
457     for (const int rule_index : nonterminal.rules) {
458       // Schedule rule.
459       scheduled_rules.insert({i, rule_index});
460     }
461 
462     if (unmergeable) {
463       // Define unique nonterminal id.
464       nonterminal_ids[i] = rules.AddUnshareableNonterminal(nonterminal.name);
465     } else {
466       nonterminal_ids[i] = rules.AddNonterminal(nonterminal.name);
467     }
468 
469     // Define regex rules.
470     for (const int regex_rule : nonterminal.regex_rules) {
471       rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
472     }
473   }
474 
475   // Define annotations.
476   for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
477     rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
478   }
479 
480   // Check whether fillers are still referenced (if they couldn't get optimized
481   // away).
482   if (UsesFillers()) {
483     TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
484                         "matching performance might be impacted.";
485 
486     // Add a definition for the filler:
487     // <filler> = <token>
488     // <filler> = <token> <filler>
489     const Nonterm filler = rules.GetNonterminalForName(kFiller);
490     const Nonterm token =
491         rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
492     rules.Add(filler, token);
493     rules.Add(filler, std::vector<Nonterm>{token, filler});
494   }
495 
496   // Now, keep adding eligible rules (rules whose rhs is completely assigned)
497   // until we can't make any more progress.
498   // Note: The following code is quadratic in the worst case.
499   // This seems fine as this will only run as part of the compilation of the
500   // grammar rules during model assembly.
501   bool changed = true;
502   while (changed) {
503     changed = false;
504     for (auto nt_and_rule = scheduled_rules.begin();
505          nt_and_rule != scheduled_rules.end();) {
506       const Rule& rule = rules_[nt_and_rule->second];
507       if (IsRhsAssigned(rule, nonterminal_ids)) {
508         // Compile the rule.
509         LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
510                   &rules);
511         scheduled_rules.erase(
512             nt_and_rule++);  // Iterator is advanced before erase.
513         changed = true;
514         break;
515       } else {
516         nt_and_rule++;
517       }
518     }
519   }
520   TC3_CHECK(scheduled_rules.empty());
521   return rules;
522 }
523 
524 }  // namespace libtextclassifier3::grammar
525