1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/datetime/grammar-parser.h"
18 
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
21 
22 #include "annotator/datetime/datetime-grounder.h"
23 #include "annotator/types.h"
24 #include "utils/grammar/analyzer.h"
25 #include "utils/grammar/evaluated-derivation.h"
26 #include "utils/grammar/parsing/derivation.h"
27 
28 using ::libtextclassifier3::grammar::EvaluatedDerivation;
29 using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
30 
31 namespace libtextclassifier3 {
32 
GrammarDatetimeParser(const grammar::Analyzer & analyzer,const DatetimeGrounder & datetime_grounder,const float target_classification_score,const float priority_score)33 GrammarDatetimeParser::GrammarDatetimeParser(
34     const grammar::Analyzer& analyzer,
35     const DatetimeGrounder& datetime_grounder,
36     const float target_classification_score, const float priority_score)
37     : analyzer_(analyzer),
38       datetime_grounder_(datetime_grounder),
39       target_classification_score_(target_classification_score),
40       priority_score_(priority_score) {}
41 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const42 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
43     const std::string& input, const int64 reference_time_ms_utc,
44     const std::string& reference_timezone, const LocaleList& locale_list,
45     ModeFlag mode, AnnotationUsecase annotation_usecase,
46     bool anchor_start_end) const {
47   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
48                reference_time_ms_utc, reference_timezone, locale_list, mode,
49                annotation_usecase, anchor_start_end);
50 }
51 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const52 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
53     const UnicodeText& input, const int64 reference_time_ms_utc,
54     const std::string& reference_timezone, const LocaleList& locale_list,
55     ModeFlag mode, AnnotationUsecase annotation_usecase,
56     bool anchor_start_end) const {
57   std::vector<DatetimeParseResultSpan> results;
58   UnsafeArena arena(/*block_size=*/16 << 10);
59   std::vector<Locale> locales = locale_list.GetLocales();
60   // If the locale list is empty then datetime regex expression will still
61   // execute but in grammar based parser the rules are associated with local
62   // and engine will not run if the locale list is empty. In an unlikely
63   // scenario when locale is not mentioned fallback to en-*.
64   if (locales.empty()) {
65     locales.emplace_back(Locale::FromBCP47("en"));
66   }
67   TC3_ASSIGN_OR_RETURN(
68       const std::vector<EvaluatedDerivation> evaluated_derivations,
69       analyzer_.Parse(input, locales, &arena,
70                       /*deduplicate_derivations=*/false));
71 
72   std::vector<EvaluatedDerivation> valid_evaluated_derivations;
73   for (const EvaluatedDerivation& evaluated_derivation :
74        evaluated_derivations) {
75     if (evaluated_derivation.value) {
76       if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
77         const UngroundedDatetime* ungrounded_datetime =
78             evaluated_derivation.value->Table<UngroundedDatetime>();
79         if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
80           valid_evaluated_derivations.emplace_back(evaluated_derivation);
81         }
82       }
83     }
84   }
85   valid_evaluated_derivations =
86       grammar::DeduplicateDerivations(valid_evaluated_derivations);
87   for (const EvaluatedDerivation& evaluated_derivation :
88        valid_evaluated_derivations) {
89     if (evaluated_derivation.value) {
90       if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
91         const UngroundedDatetime* ungrounded_datetime =
92             evaluated_derivation.value->Table<UngroundedDatetime>();
93         if ((ungrounded_datetime->annotation_usecases() &
94              (1 << annotation_usecase)) == 0) {
95           continue;
96         }
97         const StatusOr<std::vector<DatetimeParseResult>>&
98             datetime_parse_results = datetime_grounder_.Ground(
99                 reference_time_ms_utc, reference_timezone,
100                 locale_list.GetReferenceLocale(), ungrounded_datetime);
101         TC3_ASSIGN_OR_RETURN(
102             const std::vector<DatetimeParseResult>& parse_datetime,
103             datetime_parse_results);
104         DatetimeParseResultSpan datetime_parse_result_span;
105         datetime_parse_result_span.target_classification_score =
106             target_classification_score_;
107         datetime_parse_result_span.priority_score = priority_score_;
108         datetime_parse_result_span.data.reserve(parse_datetime.size());
109         datetime_parse_result_span.data.insert(
110             datetime_parse_result_span.data.end(), parse_datetime.begin(),
111             parse_datetime.end());
112         datetime_parse_result_span.span =
113             evaluated_derivation.parse_tree->codepoint_span;
114 
115         results.emplace_back(datetime_parse_result_span);
116       }
117     }
118   }
119   return results;
120 }
121 }  // namespace libtextclassifier3
122