1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
18 #define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
19 
#include <functional>
#include <memory>
#include <vector>

#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
26 
27 namespace libtextclassifier {
28 
29 // Holds state for extracting features across multiple calls and reusing them.
30 // Assumes that features for each Token are independent.
31 class CachedFeatures {
32  public:
33   // Extracts the features for the given sequence of tokens.
34   //  - context_size: Specifies how many tokens to the left, and how many
35   //                   tokens to the right spans the context.
36   //  - sparse_features, dense_features: Extracted features for each token.
37   //  - feature_vector_fn: Writes features for given Token to the specified
38   //                       storage.
39   //                       NOTE: The function can assume that the underlying
40   //                       storage is initialized to all zeros.
41   //  - feature_vector_size: Size of a feature vector for one Token.
CachedFeatures(VectorSpan<Token> tokens,int context_size,const std::vector<std::vector<int>> & sparse_features,const std::vector<std::vector<float>> & dense_features,const std::function<bool (const std::vector<int> &,const std::vector<float> &,float *)> & feature_vector_fn,int feature_vector_size)42   CachedFeatures(VectorSpan<Token> tokens, int context_size,
43                  const std::vector<std::vector<int>>& sparse_features,
44                  const std::vector<std::vector<float>>& dense_features,
45                  const std::function<bool(const std::vector<int>&,
46                                           const std::vector<float>&, float*)>&
47                      feature_vector_fn,
48                  int feature_vector_size)
49       : tokens_(tokens),
50         context_size_(context_size),
51         feature_vector_size_(feature_vector_size),
52         remap_v0_feature_vector_(false),
53         remap_v0_chargram_embedding_size_(-1) {
54     Extract(sparse_features, dense_features, feature_vector_fn);
55   }
56 
57   // Gets a VectorSpan with the features for given click position.
58   bool Get(int click_pos, VectorSpan<float>* features,
59            VectorSpan<Token>* output_tokens);
60 
61   // Turns on a compatibility mode, which re-maps the extracted features to the
62   // v0 feature format (where the dense features were at the end).
63   // WARNING: Internally v0_feature_storage_ is used as a backing buffer for
64   // VectorSpan<float>, so the output of Extract is valid only until the next
65   // call or destruction of the current CachedFeatures object.
66   // TODO(zilka): Remove when we'll have retrained models.
SetV0FeatureMode(int chargram_embedding_size)67   void SetV0FeatureMode(int chargram_embedding_size) {
68     remap_v0_feature_vector_ = true;
69     remap_v0_chargram_embedding_size_ = chargram_embedding_size;
70     v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
71   }
72 
73  protected:
74   // Extracts features for all tokens and stores them for later retrieval.
75   void Extract(const std::vector<std::vector<int>>& sparse_features,
76                const std::vector<std::vector<float>>& dense_features,
77                const std::function<bool(const std::vector<int>&,
78                                         const std::vector<float>&, float*)>&
79                    feature_vector_fn);
80 
81   // Remaps extracted features to V0 feature format. The mapping is using
82   // the v0_feature_storage_ as the backing storage for the mapped features.
83   // For each token the features consist of:
84   //  - chargram embeddings
85   //  - dense features
86   // They are concatenated together as [chargram embeddings; dense features]
87   // for each token independently.
88   // The V0 features require that the chargram embeddings for tokens are
89   // concatenated first together, and at the end, the dense features for the
90   // tokens are concatenated to it.
91   void RemapV0FeatureVector(VectorSpan<float>* features);
92 
93  private:
94   const VectorSpan<Token> tokens_;
95   const int context_size_;
96   const int feature_vector_size_;
97   bool remap_v0_feature_vector_;
98   int remap_v0_chargram_embedding_size_;
99 
100   std::vector<float> features_;
101   std::vector<float> v0_feature_storage_;
102 };
103 
104 }  // namespace libtextclassifier
105 
106 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
107