1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/datetime/parser.h"
18 
19 #include <set>
20 #include <unordered_set>
21 
22 #include "annotator/datetime/extractor.h"
23 #include "utils/calendar/calendar.h"
24 #include "utils/i18n/locale.h"
25 #include "utils/strings/split.h"
26 #include "utils/zlib/zlib_regex.h"
27 
28 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib & unilib,const CalendarLib & calendarlib,ZlibDecompressor * decompressor)29 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
30     const DatetimeModel* model, const UniLib& unilib,
31     const CalendarLib& calendarlib, ZlibDecompressor* decompressor) {
32   std::unique_ptr<DatetimeParser> result(
33       new DatetimeParser(model, unilib, calendarlib, decompressor));
34   if (!result->initialized_) {
35     result.reset();
36   }
37   return result;
38 }
39 
DatetimeParser(const DatetimeModel * model,const UniLib & unilib,const CalendarLib & calendarlib,ZlibDecompressor * decompressor)40 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
41                                const CalendarLib& calendarlib,
42                                ZlibDecompressor* decompressor)
43     : unilib_(unilib), calendarlib_(calendarlib) {
44   initialized_ = false;
45 
46   if (model == nullptr) {
47     return;
48   }
49 
50   if (model->patterns() != nullptr) {
51     for (const DatetimeModelPattern* pattern : *model->patterns()) {
52       if (pattern->regexes()) {
53         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
54           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
55               UncompressMakeRegexPattern(
56                   unilib, regex->pattern(), regex->compressed_pattern(),
57                   model->lazy_regex_compilation(), decompressor);
58           if (!regex_pattern) {
59             TC3_LOG(ERROR) << "Couldn't create rule pattern.";
60             return;
61           }
62           rules_.push_back({std::move(regex_pattern), regex, pattern});
63           if (pattern->locales()) {
64             for (int locale : *pattern->locales()) {
65               locale_to_rules_[locale].push_back(rules_.size() - 1);
66             }
67           }
68         }
69       }
70     }
71   }
72 
73   if (model->extractors() != nullptr) {
74     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
75       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
76           UncompressMakeRegexPattern(
77               unilib, extractor->pattern(), extractor->compressed_pattern(),
78               model->lazy_regex_compilation(), decompressor);
79       if (!regex_pattern) {
80         TC3_LOG(ERROR) << "Couldn't create extractor pattern";
81         return;
82       }
83       extractor_rules_.push_back(std::move(regex_pattern));
84 
85       if (extractor->locales()) {
86         for (int locale : *extractor->locales()) {
87           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
88               extractor_rules_.size() - 1;
89         }
90       }
91     }
92   }
93 
94   if (model->locales() != nullptr) {
95     for (int i = 0; i < model->locales()->Length(); ++i) {
96       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
97     }
98   }
99 
100   if (model->default_locales() != nullptr) {
101     for (const int locale : *model->default_locales()) {
102       default_locale_ids_.push_back(locale);
103     }
104   }
105 
106   use_extractors_for_locating_ = model->use_extractors_for_locating();
107   generate_alternative_interpretations_when_ambiguous_ =
108       model->generate_alternative_interpretations_when_ambiguous();
109 
110   initialized_ = true;
111 }
112 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const113 bool DatetimeParser::Parse(
114     const std::string& input, const int64 reference_time_ms_utc,
115     const std::string& reference_timezone, const std::string& locales,
116     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
117     std::vector<DatetimeParseResultSpan>* results) const {
118   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
119                reference_time_ms_utc, reference_timezone, locales, mode,
120                annotation_usecase, anchor_start_end, results);
121 }
122 
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules,std::vector<DatetimeParseResultSpan> * found_spans) const123 bool DatetimeParser::FindSpansUsingLocales(
124     const std::vector<int>& locale_ids, const UnicodeText& input,
125     const int64 reference_time_ms_utc, const std::string& reference_timezone,
126     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
127     const std::string& reference_locale,
128     std::unordered_set<int>* executed_rules,
129     std::vector<DatetimeParseResultSpan>* found_spans) const {
130   for (const int locale_id : locale_ids) {
131     auto rules_it = locale_to_rules_.find(locale_id);
132     if (rules_it == locale_to_rules_.end()) {
133       continue;
134     }
135 
136     for (const int rule_id : rules_it->second) {
137       // Skip rules that were already executed in previous locales.
138       if (executed_rules->find(rule_id) != executed_rules->end()) {
139         continue;
140       }
141 
142       if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
143            (1 << annotation_usecase)) == 0) {
144         continue;
145       }
146 
147       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
148         continue;
149       }
150 
151       executed_rules->insert(rule_id);
152 
153       if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
154                          reference_timezone, reference_locale, locale_id,
155                          anchor_start_end, found_spans)) {
156         return false;
157       }
158     }
159   }
160   return true;
161 }
162 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const163 bool DatetimeParser::Parse(
164     const UnicodeText& input, const int64 reference_time_ms_utc,
165     const std::string& reference_timezone, const std::string& locales,
166     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
167     std::vector<DatetimeParseResultSpan>* results) const {
168   std::vector<DatetimeParseResultSpan> found_spans;
169   std::unordered_set<int> executed_rules;
170   std::string reference_locale;
171   const std::vector<int> requested_locales =
172       ParseAndExpandLocales(locales, &reference_locale);
173   if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
174                              reference_timezone, mode, annotation_usecase,
175                              anchor_start_end, reference_locale,
176                              &executed_rules, &found_spans)) {
177     return false;
178   }
179 
180   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
181   indexed_found_spans.reserve(found_spans.size());
182   for (int i = 0; i < found_spans.size(); i++) {
183     indexed_found_spans.push_back({found_spans[i], i});
184   }
185 
186   // Resolve conflicts by always picking the longer span and breaking ties by
187   // selecting the earlier entry in the list for a given locale.
188   std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
189             [](const std::pair<DatetimeParseResultSpan, int>& a,
190                const std::pair<DatetimeParseResultSpan, int>& b) {
191               if ((a.first.span.second - a.first.span.first) !=
192                   (b.first.span.second - b.first.span.first)) {
193                 return (a.first.span.second - a.first.span.first) >
194                        (b.first.span.second - b.first.span.first);
195               } else {
196                 return a.second < b.second;
197               }
198             });
199 
200   found_spans.clear();
201   for (auto& span_index_pair : indexed_found_spans) {
202     found_spans.push_back(span_index_pair.first);
203   }
204 
205   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
206       [&found_spans](int a, int b) {
207         return found_spans[a].span.first < found_spans[b].span.first;
208       });
209   for (int i = 0; i < found_spans.size(); ++i) {
210     if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
211       chosen_indices_set.insert(i);
212       results->push_back(found_spans[i]);
213     }
214   }
215 
216   return true;
217 }
218 
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResultSpan> * result) const219 bool DatetimeParser::HandleParseMatch(
220     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
221     int64 reference_time_ms_utc, const std::string& reference_timezone,
222     const std::string& reference_locale, int locale_id,
223     std::vector<DatetimeParseResultSpan>* result) const {
224   int status = UniLib::RegexMatcher::kNoError;
225   const int start = matcher.Start(&status);
226   if (status != UniLib::RegexMatcher::kNoError) {
227     return false;
228   }
229 
230   const int end = matcher.End(&status);
231   if (status != UniLib::RegexMatcher::kNoError) {
232     return false;
233   }
234 
235   DatetimeParseResultSpan parse_result;
236   std::vector<DatetimeParseResult> alternatives;
237   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
238                        reference_locale, locale_id, &alternatives,
239                        &parse_result.span)) {
240     return false;
241   }
242 
243   if (!use_extractors_for_locating_) {
244     parse_result.span = {start, end};
245   }
246 
247   if (parse_result.span.first != kInvalidIndex &&
248       parse_result.span.second != kInvalidIndex) {
249     parse_result.target_classification_score =
250         rule.pattern->target_classification_score();
251     parse_result.priority_score = rule.pattern->priority_score();
252 
253     for (DatetimeParseResult& alternative : alternatives) {
254       parse_result.data.push_back(alternative);
255     }
256   }
257   result->push_back(parse_result);
258   return true;
259 }
260 
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * result) const261 bool DatetimeParser::ParseWithRule(
262     const CompiledRule& rule, const UnicodeText& input,
263     const int64 reference_time_ms_utc, const std::string& reference_timezone,
264     const std::string& reference_locale, const int locale_id,
265     bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
266   std::unique_ptr<UniLib::RegexMatcher> matcher =
267       rule.compiled_regex->Matcher(input);
268   int status = UniLib::RegexMatcher::kNoError;
269   if (anchor_start_end) {
270     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
271       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
272                             reference_timezone, reference_locale, locale_id,
273                             result)) {
274         return false;
275       }
276     }
277   } else {
278     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
279       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
280                             reference_timezone, reference_locale, locale_id,
281                             result)) {
282         return false;
283       }
284     }
285   }
286   return true;
287 }
288 
ParseAndExpandLocales(const std::string & locales,std::string * reference_locale) const289 std::vector<int> DatetimeParser::ParseAndExpandLocales(
290     const std::string& locales, std::string* reference_locale) const {
291   std::vector<StringPiece> split_locales = strings::Split(locales, ',');
292   if (!split_locales.empty()) {
293     *reference_locale = split_locales[0].ToString();
294   } else {
295     *reference_locale = "";
296   }
297 
298   std::vector<int> result;
299   for (const StringPiece& locale_str : split_locales) {
300     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
301     if (locale_it != locale_string_to_id_.end()) {
302       result.push_back(locale_it->second);
303     }
304 
305     const Locale locale = Locale::FromBCP47(locale_str.ToString());
306     if (!locale.IsValid()) {
307       continue;
308     }
309 
310     const std::string language = locale.Language();
311     const std::string script = locale.Script();
312     const std::string region = locale.Region();
313 
314     // First, try adding *-region locale.
315     if (!region.empty()) {
316       locale_it = locale_string_to_id_.find("*-" + region);
317       if (locale_it != locale_string_to_id_.end()) {
318         result.push_back(locale_it->second);
319       }
320     }
321     // Second, try adding language-script-* locale.
322     if (!script.empty()) {
323       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
324       if (locale_it != locale_string_to_id_.end()) {
325         result.push_back(locale_it->second);
326       }
327     }
328     // Third, try adding language-* locale.
329     if (!language.empty()) {
330       locale_it = locale_string_to_id_.find(language + "-*");
331       if (locale_it != locale_string_to_id_.end()) {
332         result.push_back(locale_it->second);
333       }
334     }
335   }
336 
337   // Add the default locales if they haven't been added already.
338   const std::unordered_set<int> result_set(result.begin(), result.end());
339   for (const int default_locale_id : default_locale_ids_) {
340     if (result_set.find(default_locale_id) == result_set.end()) {
341       result.push_back(default_locale_id);
342     }
343   }
344 
345   return result;
346 }
347 
FillInterpretations(const DateParseData & parse,std::vector<DateParseData> * interpretations) const348 void DatetimeParser::FillInterpretations(
349     const DateParseData& parse,
350     std::vector<DateParseData>* interpretations) const {
351   DatetimeGranularity granularity = calendarlib_.GetGranularity(parse);
352 
353   DateParseData modified_parse(parse);
354   // If the relation field is not set, but relation_type field *is*, assume
355   // the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
356   // "monday 3pm" (otherwise only "this monday 3pm" would work).
357   if (!(modified_parse.field_set_mask &
358         DateParseData::Fields::RELATION_FIELD) &&
359       (modified_parse.field_set_mask &
360        DateParseData::Fields::RELATION_TYPE_FIELD)) {
361     modified_parse.relation = DateParseData::Relation::NEXT_OR_SAME;
362     modified_parse.field_set_mask |= DateParseData::Fields::RELATION_FIELD;
363   }
364 
365   // Multiple interpretations of ambiguous datetime expressions are generated
366   // here.
367   if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
368       (modified_parse.field_set_mask & DateParseData::Fields::HOUR_FIELD) &&
369       modified_parse.hour <= 12 &&
370       !(modified_parse.field_set_mask & DateParseData::Fields::AMPM_FIELD)) {
371     // If it's not clear if the time is AM or PM, generate all variants.
372     interpretations->push_back(modified_parse);
373     interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
374     interpretations->back().ampm = DateParseData::AMPM::AM;
375 
376     interpretations->push_back(modified_parse);
377     interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
378     interpretations->back().ampm = DateParseData::AMPM::PM;
379   } else {
380     // Otherwise just generate 1 variant.
381     interpretations->push_back(modified_parse);
382   }
383   // TODO(zilka): Add support for generating alternatives for "monday" -> "this
384   // monday", "next monday", "last monday". The previous implementation did not
385   // work as expected, because didn't work correctly for this/previous day of
386   // week, and resulted sometimes results in the same date being proposed.
387 }
388 
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const389 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
390                                      const UniLib::RegexMatcher& matcher,
391                                      const int64 reference_time_ms_utc,
392                                      const std::string& reference_timezone,
393                                      const std::string& reference_locale,
394                                      int locale_id,
395                                      std::vector<DatetimeParseResult>* results,
396                                      CodepointSpan* result_span) const {
397   DateParseData parse;
398   DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
399                               extractor_rules_,
400                               type_and_locale_to_extractor_rule_);
401   if (!extractor.Extract(&parse, result_span)) {
402     return false;
403   }
404 
405   std::vector<DateParseData> interpretations;
406   if (generate_alternative_interpretations_when_ambiguous_) {
407     FillInterpretations(parse, &interpretations);
408   } else {
409     interpretations.push_back(parse);
410   }
411 
412   results->reserve(results->size() + interpretations.size());
413   for (const DateParseData& interpretation : interpretations) {
414     DatetimeParseResult result;
415     if (!calendarlib_.InterpretParseData(
416             interpretation, reference_time_ms_utc, reference_timezone,
417             reference_locale, &(result.time_ms_utc), &(result.granularity))) {
418       return false;
419     }
420     results->push_back(result);
421   }
422   return true;
423 }
424 
425 }  // namespace libtextclassifier3
426