1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/number/number.h"
18 
19 #include <climits>
20 #include <cstdlib>
21 
22 #include "annotator/collections.h"
23 #include "utils/base/logging.h"
24 
25 namespace libtextclassifier3 {
26 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const27 bool NumberAnnotator::ClassifyText(
28     const UnicodeText& context, CodepointSpan selection_indices,
29     AnnotationUsecase annotation_usecase,
30     ClassificationResult* classification_result) const {
31   int64 parsed_value;
32   int num_prefix_codepoints;
33   int num_suffix_codepoints;
34   if (ParseNumber(UnicodeText::Substring(context, selection_indices.first,
35                                          selection_indices.second),
36                   &parsed_value, &num_prefix_codepoints,
37                   &num_suffix_codepoints)) {
38     ClassificationResult classification{Collections::Number(), 1.0};
39     TC3_CHECK(classification_result != nullptr);
40     classification_result->collection = Collections::Number();
41     classification_result->score = options_->score();
42     classification_result->priority_score = options_->priority_score();
43     classification_result->numeric_value = parsed_value;
44     return true;
45   }
46   return false;
47 }
48 
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * result) const49 bool NumberAnnotator::FindAll(const UnicodeText& context,
50                               AnnotationUsecase annotation_usecase,
51                               std::vector<AnnotatedSpan>* result) const {
52   if (!options_->enabled() || ((1 << annotation_usecase) &
53                                options_->enabled_annotation_usecases()) == 0) {
54     return true;
55   }
56 
57   const std::vector<Token> tokens = feature_processor_->Tokenize(context);
58   for (const Token& token : tokens) {
59     const UnicodeText token_text =
60         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
61     int64 parsed_value;
62     int num_prefix_codepoints;
63     int num_suffix_codepoints;
64     if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints,
65                     &num_suffix_codepoints)) {
66       ClassificationResult classification{Collections::Number(),
67                                           options_->score()};
68       classification.numeric_value = parsed_value;
69       classification.priority_score = options_->priority_score();
70 
71       AnnotatedSpan annotated_span;
72       annotated_span.span = {token.start + num_prefix_codepoints,
73                              token.end - num_suffix_codepoints};
74       annotated_span.classification.push_back(classification);
75 
76       result->push_back(annotated_span);
77     }
78   }
79 
80   return true;
81 }
82 
FlatbuffersVectorToSet(const flatbuffers::Vector<int32_t> * codepoints)83 std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
84     const flatbuffers::Vector<int32_t>* codepoints) {
85   if (codepoints == nullptr) {
86     return std::unordered_set<int>{};
87   }
88 
89   std::unordered_set<int> result;
90   for (const int codepoint : *codepoints) {
91     result.insert(codepoint);
92   }
93   return result;
94 }
95 
96 namespace {
ConsumeAndParseNumber(const UnicodeText::const_iterator & it_begin,const UnicodeText::const_iterator & it_end,int64 * result)97 UnicodeText::const_iterator ConsumeAndParseNumber(
98     const UnicodeText::const_iterator& it_begin,
99     const UnicodeText::const_iterator& it_end, int64* result) {
100   *result = 0;
101 
102   // See if there's a sign in the beginning of the number.
103   int sign = 1;
104   auto it = it_begin;
105   if (it != it_end) {
106     if (*it == '-') {
107       ++it;
108       sign = -1;
109     } else if (*it == '+') {
110       ++it;
111       sign = 1;
112     }
113   }
114 
115   while (it != it_end) {
116     if (*it >= '0' && *it <= '9') {
117       // When overflow is imminent we'll fail to parse the number.
118       if (*result > INT64_MAX / 10) {
119         return it_begin;
120       }
121       *result *= 10;
122       *result += *it - '0';
123     } else {
124       *result *= sign;
125       return it;
126     }
127 
128     ++it;
129   }
130 
131   *result *= sign;
132   return it_end;
133 }
134 }  // namespace
135 
ParseNumber(const UnicodeText & text,int64 * result,int * num_prefix_codepoints,int * num_suffix_codepoints) const136 bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
137                                   int* num_prefix_codepoints,
138                                   int* num_suffix_codepoints) const {
139   TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
140             num_suffix_codepoints != nullptr);
141   auto it = text.begin();
142   auto it_end = text.end();
143 
144   // Strip boundary codepoints from both ends.
145   const CodepointSpan original_span{0, text.size_codepoints()};
146   const CodepointSpan stripped_span =
147       feature_processor_->StripBoundaryCodepoints(text, original_span);
148   const int num_stripped_end = (original_span.second - stripped_span.second);
149   std::advance(it, stripped_span.first);
150   std::advance(it_end, -num_stripped_end);
151 
152   // Consume prefix codepoints.
153   *num_prefix_codepoints = stripped_span.first;
154   while (it != text.end()) {
155     if (allowed_prefix_codepoints_.find(*it) ==
156         allowed_prefix_codepoints_.end()) {
157       break;
158     }
159 
160     ++it;
161     ++(*num_prefix_codepoints);
162   }
163 
164   auto it_start = it;
165   it = ConsumeAndParseNumber(it, text.end(), result);
166   if (it == it_start) {
167     return false;
168   }
169 
170   // Consume suffix codepoints.
171   bool valid_suffix = true;
172   *num_suffix_codepoints = 0;
173   while (it != it_end) {
174     if (allowed_suffix_codepoints_.find(*it) ==
175         allowed_suffix_codepoints_.end()) {
176       valid_suffix = false;
177       break;
178     }
179 
180     ++it;
181     ++(*num_suffix_codepoints);
182   }
183   *num_suffix_codepoints += num_stripped_end;
184   return valid_suffix;
185 }
186 
187 }  // namespace libtextclassifier3
188