1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
18 #define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
19 
20 #include <string>
21 #include <vector>
22 
23 #include "lang_id/common/embedding-network-params.h"
24 
25 namespace libtextclassifier3 {
26 namespace mobile {
27 namespace lang_id {
28 
29 // Interface for accessing parameters for the LangId model.
30 //
31 // Note: some clients prefer to include the model parameters in the binary,
32 // others prefer loading them from a separate file.  This file provides a common
33 // interface for these alternative mechanisms.
34 class ModelProvider {
35  public:
36   virtual ~ModelProvider() = default;
37 
38   // Returns true if this ModelProvider has been succesfully constructed (e.g.,
39   // can return false if an underlying model file could not be read).  Clients
40   // should not use invalid ModelProviders.
is_valid()41   bool is_valid() { return valid_; }
42 
43   // Returns the TaskContext with parameters for the LangId model.  E.g., one
44   // important parameter specifies the features to use.
45   virtual const TaskContext *GetTaskContext() const = 0;
46 
47   // Returns parameters for the underlying Neurosis feed-forward neural network.
48   virtual const EmbeddingNetworkParams *GetNnParams() const = 0;
49 
50   // Returns list of languages recognized by the model.  Each element of the
51   // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
52   // Language at index i from the returned vector corresponds to softmax label
53   // i.
54   virtual std::vector<string> GetLanguages() const = 0;
55 
56  protected:
57   bool valid_ = false;
58 };
59 
60 }  // namespace lang_id
61 }  // namespace mobile
}  // namespace libtextclassifier3
63 
64 #endif  // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
65