1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
19 
#include <algorithm>
#include <type_traits>
#include <vector>

#include "utils/grammar/parsing/parse-tree.h"
23 
24 namespace libtextclassifier3::grammar {
25 
26 // A parse tree for a root rule.
27 struct Derivation {
28   const ParseTree* parse_tree;
29   int64 rule_id;
30 
31   // Checks that all assertions are fulfilled.
32   bool IsValid() const;
GetRuleIdDerivation33   int64 GetRuleId() const { return rule_id; }
GetParseTreeDerivation34   const ParseTree* GetParseTree() const { return parse_tree; }
35 };
36 
37 // Deduplicates rule derivations by containing overlap.
38 // The grammar system can output multiple candidates for optional parts.
39 // For example if a rule has an optional suffix, we
40 // will get two rule derivations when the suffix is present: one with and one
41 // without the suffix. We therefore deduplicate by containing overlap, viz. from
42 // two candidates we keep the longer one if it completely contains the shorter.
43 // This factory function works with any type T that extends Derivation.
44 template <typename T, typename std::enable_if<std::is_base_of<
45                           Derivation, T>::value>::type* = nullptr>
46 // std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations);
DeduplicateDerivations(const std::vector<T> & derivations)47 std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations) {
48   std::vector<T> sorted_candidates = derivations;
49 
50   std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(),
51                    [](const T& a, const T& b) {
52                      // Sort by id.
53                      if (a.GetRuleId() != b.GetRuleId()) {
54                        return a.GetRuleId() < b.GetRuleId();
55                      }
56 
57                      // Sort by increasing start.
58                      if (a.GetParseTree()->codepoint_span.first !=
59                          b.GetParseTree()->codepoint_span.first) {
60                        return a.GetParseTree()->codepoint_span.first <
61                               b.GetParseTree()->codepoint_span.first;
62                      }
63 
64                      // Sort by decreasing end.
65                      return a.GetParseTree()->codepoint_span.second >
66                             b.GetParseTree()->codepoint_span.second;
67                    });
68 
69   // Deduplicate by overlap.
70   std::vector<T> result;
71   for (int i = 0; i < sorted_candidates.size(); i++) {
72     const T& candidate = sorted_candidates[i];
73     bool eliminated = false;
74 
75     // Due to the sorting above, the candidate can only be completely
76     // intersected by a match before it in the sorted order.
77     for (int j = i - 1; j >= 0; j--) {
78       if (sorted_candidates[j].rule_id != candidate.rule_id) {
79         break;
80       }
81       if (sorted_candidates[j].parse_tree->codepoint_span.first <=
82               candidate.parse_tree->codepoint_span.first &&
83           sorted_candidates[j].parse_tree->codepoint_span.second >=
84               candidate.parse_tree->codepoint_span.second) {
85         eliminated = true;
86         break;
87       }
88     }
89     if (!eliminated) {
90       result.push_back(candidate);
91     }
92   }
93   return result;
94 }
95 
// Deduplicates and validates rule derivations.
// Only declared here; the definition lives outside this header. Takes the
// derivations by const reference and returns a new vector by value.
// NOTE(review): presumably keeps only derivations for which IsValid() holds
// and applies DeduplicateDerivations -- confirm against the definition.
std::vector<Derivation> ValidDeduplicatedDerivations(
    const std::vector<Derivation>& derivations);
99 
100 }  // namespace libtextclassifier3::grammar
101 
102 #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
103