1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/sentencepiece/normalizer.h"

#include <algorithm>

#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
21 
22 namespace libtextclassifier3 {
23 
Normalize(StringPiece input,std::string * normalized_input) const24 bool SentencePieceNormalizer::Normalize(StringPiece input,
25                                         std::string* normalized_input) const {
26   // Ignores heading space.
27   if (remove_extra_whitespaces_) {
28     while (!input.empty()) {
29       std::pair<StringPiece, int> suffix_and_length;
30       if (!NormalizePrefix(input, &suffix_and_length)) {
31         TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
32         return false;
33       }
34       if (suffix_and_length.second <= 0) {
35         TC3_LOG(ERROR) << "Consumed string is empty.";
36         return false;
37       }
38       if (suffix_and_length.first.size() != 1 ||
39           suffix_and_length.first[0] != ' ') {
40         break;
41       }
42       input.RemovePrefix(suffix_and_length.second);
43     }
44   }
45 
46   if (input.empty()) {
47     *normalized_input = "";
48     return true;
49   }
50 
51   // Reserves the output buffer to avoid re-allocations.
52   const int kReservedSize = input.size() * 3;
53   normalized_input->reserve(kReservedSize);
54 
55   // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
56   // if escape_whitespaces() is set (default = true).
57   const StringPiece kSpaceSymbol = "\xe2\x96\x81";
58 
59   // Adds a space symbol as a prefix (default is true)
60   // With this prefix, "world" and "hello world" are converted into
61   // "_world" and "_hello_world", which help the trainer to extract
62   // "_world" as one symbol.
63   if (add_dummy_prefix_) {
64     if (escape_whitespaces_) {
65       normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
66     } else {
67       normalized_input->append(" ");
68     }
69   }
70 
71   bool is_prev_space = remove_extra_whitespaces_;
72   while (!input.empty()) {
73     std::pair<StringPiece, int> p;
74     if (!NormalizePrefix(input, &p)) {
75       TC3_LOG(ERROR) << "Couldn't normalize string.";
76       return false;
77     }
78     if (p.second <= 0) {
79       TC3_LOG(ERROR) << "Consumed string is empty.";
80       return false;
81     }
82 
83     StringPiece sp = p.first;
84 
85     // Removes heading spaces in sentence piece,
86     // if the previous sentence piece ends with whitespace.
87     while (is_prev_space && ConsumePrefix(&sp, " ")) {
88     }
89 
90     if (!sp.empty()) {
91       const char* data = sp.data();
92       for (int n = 0; n < sp.size(); ++n) {
93         if (escape_whitespaces_ && data[n] == ' ') {
94           normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
95         } else {
96           *normalized_input += data[n];
97         }
98       }
99       // Checks whether the last character of sp is whitespace.
100       is_prev_space = EndsWith(sp, " ");
101     }
102     input.RemovePrefix(p.second);
103     is_prev_space = is_prev_space && remove_extra_whitespaces_;
104   }
105 
106   // Ignores tailing space.
107   if (remove_extra_whitespaces_) {
108     const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
109     while (EndsWith(*normalized_input, space)) {
110       const int length = normalized_input->size() - space.size();
111       normalized_input->resize(length);
112     }
113   }
114   return true;
115 }
116 
NormalizePrefix(StringPiece input,std::pair<StringPiece,int> * prefix) const117 bool SentencePieceNormalizer::NormalizePrefix(
118     StringPiece input, std::pair<StringPiece, int>* prefix) const {
119   if (input.empty()) return true;
120   TrieMatch match;
121   if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
122     TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
123     return false;
124   }
125   const bool no_match = match.match_length <= 0;
126   if (no_match) {
127     const int char_length = ValidUTF8CharLength(input.data(), input.size());
128     if (char_length <= 0) {
129       // Found a malformed utf8.
130       // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
131       // which is a valid Unicode of three bytes in utf8,
132       // but here we only consume one byte.
133       static const char kReplacementChar[] = "\xEF\xBF\xBD";
134       prefix->first = StringPiece(kReplacementChar, 3);
135       prefix->second = 1;  // Consumes 1 byte, buts emit 0xFFFD.
136     } else {
137       prefix->first = StringPiece(input.data(), char_length);
138       prefix->second = char_length;
139     }
140   } else {
141     if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
142       TC3_LOG(ERROR) << "Invalid entry in normalization table.";
143       return false;
144     }
145     prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
146     prefix->second = match.match_length;
147   }
148   return true;
149 }
150 
151 }  // namespace libtextclassifier3
152