1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
18 #define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
19 
20 #include <vector>
21 
22 #include "utils/base/logging.h"
23 #include "utils/sentencepiece/matcher.h"
24 #include "utils/strings/stringpiece.h"
25 
26 namespace libtextclassifier3 {
27 
28 // Encoder to segment/tokenize strings into pieces such that the sum of the
29 // scores of the pieces used is maximized.
30 class Encoder {
31  public:
32   // matcher: the list of valid sentence pieces represented as a matcher, e.g.
33   //     a trie.
34   // num_pieces: the number of pieces in the trie.
35   // pieces_scores: the scores of the individual pieces.
36   // start_code: code that is used as encoding of the start of input.
37   // end_code: code that is used as encoding of the end of input.
38   // encoding_offset: value added to the sentence piece ids to make them
39   //     not interesecting with start_code and end_code.
40   // unknown_code: code that is used for out-of-dictionary characters.
41   // unknown_score: the penality score associated with the unknown code.
42   Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
43           const float* pieces_scores, int start_code = 0, int end_code = 1,
44           int encoding_offset = 2, int unknown_code = -1,
45           float unknown_score = 0.f)
num_pieces_(num_pieces)46       : num_pieces_(num_pieces),
47         scores_(pieces_scores),
48         matcher_(matcher),
49         start_code_(start_code),
50         end_code_(end_code),
51         encoding_offset_(encoding_offset),
52         unknown_code_(unknown_code),
53         unknown_score_(unknown_score) {}
54 
55   // Segment the input so that the total score of the pieces used is maximized.
56   // This is a simplified implementation of the general Viterbi algorithm,
57   // assuming independence between individual pieces.
58   bool Encode(StringPiece normalized_text,
59               std::vector<int>* encoded_text) const;
60 
61  private:
62   // State in the dynamic programming algorithm.
63   struct SegmentationEntry {
64     // Accumulated score.
65     float score;
66 
67     // Position before last piece.
68     int previous_pos;
69 
70     // Last piece used.
71     int piece_id;
72 
73     // Total number of pieces used.
74     int num_pieces;
75   };
76 
77   const int num_pieces_;
78   const float* scores_;
79   const SentencePieceMatcher* matcher_;
80   const int start_code_;
81   const int end_code_;
82   const int encoding_offset_;
83   const int unknown_code_;
84   const int unknown_score_;
85 };
86 
87 }  // namespace libtextclassifier3
88 
89 #endif  // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
90