1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/datetime/testing/base-parser-test.h"
18 
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>
22 
23 #include "utils/i18n/locale-list.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 
27 using std::vector;
28 using testing::ElementsAreArray;
29 
30 namespace libtextclassifier3 {
31 
HasNoResult(const std::string & text,bool anchor_start_end,const std::string & timezone,AnnotationUsecase annotation_usecase)32 bool DateTimeParserTest::HasNoResult(const std::string& text,
33                                      bool anchor_start_end,
34                                      const std::string& timezone,
35                                      AnnotationUsecase annotation_usecase) {
36   StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
37       DatetimeParserForTests()->Parse(
38           text, 0, timezone, LocaleList::ParseFrom(/*locale_tags=*/""),
39           ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
40   if (!results_status.ok()) {
41     TC3_LOG(ERROR) << text;
42     TC3_CHECK(false);
43   }
44   return results_status.ValueOrDie().empty();
45 }
46 
ParsesCorrectly(const std::string & marked_text,const vector<int64> & expected_ms_utcs,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components,bool anchor_start_end,const std::string & timezone,const std::string & locales,AnnotationUsecase annotation_usecase)47 bool DateTimeParserTest::ParsesCorrectly(
48     const std::string& marked_text, const vector<int64>& expected_ms_utcs,
49     DatetimeGranularity expected_granularity,
50     vector<vector<DatetimeComponent>> datetime_components,
51     bool anchor_start_end, const std::string& timezone,
52     const std::string& locales, AnnotationUsecase annotation_usecase) {
53   const UnicodeText marked_text_unicode =
54       UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
55   auto brace_open_it =
56       std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
57   auto brace_end_it =
58       std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
59   TC3_CHECK(brace_open_it != marked_text_unicode.end());
60   TC3_CHECK(brace_end_it != marked_text_unicode.end());
61 
62   std::string text;
63   text +=
64       UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
65   text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
66   text += UnicodeText::UTF8Substring(std::next(brace_end_it),
67                                      marked_text_unicode.end());
68 
69   StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
70       DatetimeParserForTests()->Parse(
71           text, 0, timezone, LocaleList::ParseFrom(locales),
72           ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
73   if (!results_status.ok()) {
74     TC3_LOG(ERROR) << text;
75     TC3_CHECK(false);
76   }
77   // const std::vector<DatetimeParseResultSpan>& results =
78   //     results_status.ValueOrDie();
79   if (results_status.ValueOrDie().empty()) {
80     TC3_LOG(ERROR) << "No results.";
81     return false;
82   }
83 
84   const int expected_start_index =
85       std::distance(marked_text_unicode.begin(), brace_open_it);
86   // The -1 below is to account for the opening bracket character.
87   const int expected_end_index =
88       std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
89 
90   std::vector<DatetimeParseResultSpan> filtered_results;
91   for (const DatetimeParseResultSpan& result : results_status.ValueOrDie()) {
92     if (SpansOverlap(result.span, {expected_start_index, expected_end_index})) {
93       filtered_results.push_back(result);
94     }
95   }
96   std::vector<DatetimeParseResultSpan> expected{
97       {{expected_start_index, expected_end_index},
98        {},
99        /*target_classification_score=*/1.0,
100        /*priority_score=*/1.0}};
101   expected[0].data.resize(expected_ms_utcs.size());
102   for (int i = 0; i < expected_ms_utcs.size(); i++) {
103     expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
104                            datetime_components[i]};
105   }
106 
107   const bool matches =
108       testing::Matches(ElementsAreArray(expected))(filtered_results);
109   if (!matches) {
110     TC3_LOG(ERROR) << "Expected: " << expected[0];
111     if (filtered_results.empty()) {
112       TC3_LOG(ERROR) << "But got no results.";
113     }
114     TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
115   }
116 
117   return matches;
118 }
119 
ParsesCorrectly(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components,bool anchor_start_end,const std::string & timezone,const std::string & locales,AnnotationUsecase annotation_usecase)120 bool DateTimeParserTest::ParsesCorrectly(
121     const std::string& marked_text, const int64 expected_ms_utc,
122     DatetimeGranularity expected_granularity,
123     vector<vector<DatetimeComponent>> datetime_components,
124     bool anchor_start_end, const std::string& timezone,
125     const std::string& locales, AnnotationUsecase annotation_usecase) {
126   return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
127                          expected_granularity, datetime_components,
128                          anchor_start_end, timezone, locales,
129                          annotation_usecase);
130 }
131 
ParsesCorrectlyGerman(const std::string & marked_text,const vector<int64> & expected_ms_utcs,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)132 bool DateTimeParserTest::ParsesCorrectlyGerman(
133     const std::string& marked_text, const vector<int64>& expected_ms_utcs,
134     DatetimeGranularity expected_granularity,
135     vector<vector<DatetimeComponent>> datetime_components) {
136   return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
137                          datetime_components,
138                          /*anchor_start_end=*/false,
139                          /*timezone=*/"Europe/Zurich", /*locales=*/"de");
140 }
141 
ParsesCorrectlyGerman(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)142 bool DateTimeParserTest::ParsesCorrectlyGerman(
143     const std::string& marked_text, const int64 expected_ms_utc,
144     DatetimeGranularity expected_granularity,
145     vector<vector<DatetimeComponent>> datetime_components) {
146   return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
147                          datetime_components,
148                          /*anchor_start_end=*/false,
149                          /*timezone=*/"Europe/Zurich", /*locales=*/"de");
150 }
151 
ParsesCorrectlyChinese(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)152 bool DateTimeParserTest::ParsesCorrectlyChinese(
153     const std::string& marked_text, const int64 expected_ms_utc,
154     DatetimeGranularity expected_granularity,
155     vector<vector<DatetimeComponent>> datetime_components) {
156   return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
157                          datetime_components,
158                          /*anchor_start_end=*/false,
159                          /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
160 }
161 
162 }  // namespace libtextclassifier3
163