1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/fb_model/model-provider-from-fb.h"
18 
19 #include "lang_id/common/file/file-utils.h"
20 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
21 #include "lang_id/common/flatbuffers/model-utils.h"
22 #include "lang_id/common/lite_strings/str-split.h"
23 
24 namespace libtextclassifier3 {
25 namespace mobile {
26 namespace lang_id {
27 
ModelProviderFromFlatbuffer(const string & filename)28 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
29 
30     // Using mmap as a fast way to read the model bytes.  As the file is
31     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
32     // stay alive for the entire lifetime of this object.
33     : scoped_mmap_(new ScopedMmap(filename)) {
34   Initialize(scoped_mmap_->handle().to_stringpiece());
35 }
36 
ModelProviderFromFlatbuffer(int fd)37 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
38 
39     // Using mmap as a fast way to read the model bytes.  As the file is
40     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
41     // stay alive for the entire lifetime of this object.
42     : scoped_mmap_(new ScopedMmap(fd)) {
43   Initialize(scoped_mmap_->handle().to_stringpiece());
44 }
45 
Initialize(StringPiece model_bytes)46 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
47   // Note: valid_ was initialized to false.  In the code below, we set valid_ to
48   // true only if all initialization steps completed successfully.  Otherwise,
49   // we return early, leaving valid_ to its default value false.
50   model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
51   if (model_ == nullptr) {
52     SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
53     return;
54   }
55 
56   // Initialize context_ parameters.
57   if (!saft_fbs::FillParameters(*model_, &context_)) {
58     // FillParameters already performs error logging.
59     return;
60   }
61 
62   // Init languages_.
63   const string known_languages_str = context_.Get("supported_languages", "");
64   for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
65     languages_.emplace_back(sp);
66   }
67   if (languages_.empty()) {
68     SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
69     return;
70   }
71 
72   // Init nn_params_.
73   if (!InitNetworkParams()) {
74     // InitNetworkParams already performs error logging.
75     return;
76   }
77 
78   // Everything looks fine.
79   valid_ = true;
80 }
81 
InitNetworkParams()82 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
83   const string kInputName = "language-identifier-network";
84   StringPiece bytes =
85       saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
86   if ((bytes.data() == nullptr) || bytes.empty()) {
87     SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
88     return false;
89   }
90   std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
91       new EmbeddingNetworkParamsFromFlatbuffer(bytes));
92   if (!nn_params_from_fb->is_valid()) {
93     SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
94     return false;
95   }
96   nn_params_ = std::move(nn_params_from_fb);
97   return true;
98 }
99 
100 }  // namespace lang_id
101 }  // namespace mobile
}  // namespace libtextclassifier3
103