1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/sentencepiece/normalizer.h"
18
19 #include "utils/base/logging.h"
20 #include "utils/strings/utf8.h"
21
22 namespace libtextclassifier3 {
23
Normalize(StringPiece input,std::string * normalized_input) const24 bool SentencePieceNormalizer::Normalize(StringPiece input,
25 std::string* normalized_input) const {
26 // Ignores heading space.
27 if (remove_extra_whitespaces_) {
28 while (!input.empty()) {
29 std::pair<StringPiece, int> suffix_and_length;
30 if (!NormalizePrefix(input, &suffix_and_length)) {
31 TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
32 return false;
33 }
34 if (suffix_and_length.second <= 0) {
35 TC3_LOG(ERROR) << "Consumed string is empty.";
36 return false;
37 }
38 if (suffix_and_length.first.size() != 1 ||
39 suffix_and_length.first[0] != ' ') {
40 break;
41 }
42 input.RemovePrefix(suffix_and_length.second);
43 }
44 }
45
46 if (input.empty()) {
47 *normalized_input = "";
48 return true;
49 }
50
51 // Reserves the output buffer to avoid re-allocations.
52 const int kReservedSize = input.size() * 3;
53 normalized_input->reserve(kReservedSize);
54
55 // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
56 // if escape_whitespaces() is set (default = true).
57 const StringPiece kSpaceSymbol = "\xe2\x96\x81";
58
59 // Adds a space symbol as a prefix (default is true)
60 // With this prefix, "world" and "hello world" are converted into
61 // "_world" and "_hello_world", which help the trainer to extract
62 // "_world" as one symbol.
63 if (add_dummy_prefix_) {
64 if (escape_whitespaces_) {
65 normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
66 } else {
67 normalized_input->append(" ");
68 }
69 }
70
71 bool is_prev_space = remove_extra_whitespaces_;
72 while (!input.empty()) {
73 std::pair<StringPiece, int> p;
74 if (!NormalizePrefix(input, &p)) {
75 TC3_LOG(ERROR) << "Couldn't normalize string.";
76 return false;
77 }
78 if (p.second <= 0) {
79 TC3_LOG(ERROR) << "Consumed string is empty.";
80 return false;
81 }
82
83 StringPiece sp = p.first;
84
85 // Removes heading spaces in sentence piece,
86 // if the previous sentence piece ends with whitespace.
87 while (is_prev_space && ConsumePrefix(&sp, " ")) {
88 }
89
90 if (!sp.empty()) {
91 const char* data = sp.data();
92 for (int n = 0; n < sp.size(); ++n) {
93 if (escape_whitespaces_ && data[n] == ' ') {
94 normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
95 } else {
96 *normalized_input += data[n];
97 }
98 }
99 // Checks whether the last character of sp is whitespace.
100 is_prev_space = EndsWith(sp, " ");
101 }
102 input.RemovePrefix(p.second);
103 is_prev_space = is_prev_space && remove_extra_whitespaces_;
104 }
105
106 // Ignores tailing space.
107 if (remove_extra_whitespaces_) {
108 const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
109 while (EndsWith(*normalized_input, space)) {
110 const int length = normalized_input->size() - space.size();
111 normalized_input->resize(length);
112 }
113 }
114 return true;
115 }
116
NormalizePrefix(StringPiece input,std::pair<StringPiece,int> * prefix) const117 bool SentencePieceNormalizer::NormalizePrefix(
118 StringPiece input, std::pair<StringPiece, int>* prefix) const {
119 if (input.empty()) return true;
120 TrieMatch match;
121 if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
122 TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
123 return false;
124 }
125 const bool no_match = match.match_length <= 0;
126 if (no_match) {
127 const int char_length = ValidUTF8CharLength(input.data(), input.size());
128 if (char_length <= 0) {
129 // Found a malformed utf8.
130 // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
131 // which is a valid Unicode of three bytes in utf8,
132 // but here we only consume one byte.
133 static const char kReplacementChar[] = "\xEF\xBF\xBD";
134 prefix->first = StringPiece(kReplacementChar, 3);
135 prefix->second = 1; // Consumes 1 byte, buts emit 0xFFFD.
136 } else {
137 prefix->first = StringPiece(input.data(), char_length);
138 prefix->second = char_length;
139 }
140 } else {
141 if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
142 TC3_LOG(ERROR) << "Invalid entry in normalization table.";
143 return false;
144 }
145 prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
146 prefix->second = match.match_length;
147 }
148 return true;
149 }
150
151 } // namespace libtextclassifier3
152