1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/vocab/vocab-level-table.h"
18 
19 #include <cstddef>
20 #include <memory>
21 
22 #include "annotator/model_generated.h"
23 #include "utils/base/endian.h"
24 #include "utils/container/bit-vector.h"
25 #include "utils/optional.h"
26 #include "marisa/trie.h"
27 
28 namespace libtextclassifier3 {
29 
Create(const VocabModel * model)30 std::unique_ptr<VocabLevelTable> VocabLevelTable::Create(
31     const VocabModel* model) {
32   if (!LittleEndian::IsLittleEndian()) {
33     // TODO(tonymak) Consider making this work on a big endian device.
34     TC3_LOG(ERROR)
35         << "VocabLevelTable is only working on a little endian device.";
36     return nullptr;
37   }
38   const flatbuffers::Vector<uint8_t>* trie_data = model->vocab_trie();
39   if (trie_data == nullptr) {
40     TC3_LOG(ERROR) << "vocab_trie is missing from the model file.";
41     return nullptr;
42   }
43   std::unique_ptr<marisa::Trie> vocab_trie(new marisa::Trie);
44   vocab_trie->map(trie_data->data(), trie_data->size());
45 
46   return std::unique_ptr<VocabLevelTable>(new VocabLevelTable(
47       model, std::move(vocab_trie), BitVector(model->beginner_level()),
48       BitVector(model->do_not_trigger_in_upper_case())));
49 }
50 
VocabLevelTable(const VocabModel * model,std::unique_ptr<marisa::Trie> vocab_trie,const BitVector beginner_level,const BitVector do_not_trigger_in_upper_case)51 VocabLevelTable::VocabLevelTable(const VocabModel* model,
52                                  std::unique_ptr<marisa::Trie> vocab_trie,
53                                  const BitVector beginner_level,
54                                  const BitVector do_not_trigger_in_upper_case)
55     : model_(model),
56       vocab_trie_(std::move(vocab_trie)),
57       beginner_level_(beginner_level),
58       do_not_trigger_in_upper_case_(do_not_trigger_in_upper_case) {}
59 
Lookup(const std::string & vocab) const60 Optional<LookupResult> VocabLevelTable::Lookup(const std::string& vocab) const {
61   marisa::Agent agent;
62   agent.set_query(vocab.data(), vocab.size());
63   if (vocab_trie_->lookup(agent)) {
64     const int vector_idx = agent.key().id();
65     return Optional<LookupResult>({beginner_level_[vector_idx],
66                                    do_not_trigger_in_upper_case_[vector_idx]});
67   }
68   return Optional<LookupResult>();
69 }
70 }  // namespace libtextclassifier3
71