1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ 18 #define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "base.h" 24 #include "common/vector-span.h" 25 #include "smartselect/types.h" 26 27 namespace libtextclassifier { 28 29 // Holds state for extracting features across multiple calls and reusing them. 30 // Assumes that features for each Token are independent. 31 class CachedFeatures { 32 public: 33 // Extracts the features for the given sequence of tokens. 34 // - context_size: Specifies how many tokens to the left, and how many 35 // tokens to the right spans the context. 36 // - sparse_features, dense_features: Extracted features for each token. 37 // - feature_vector_fn: Writes features for given Token to the specified 38 // storage. 39 // NOTE: The function can assume that the underlying 40 // storage is initialized to all zeros. 41 // - feature_vector_size: Size of a feature vector for one Token. CachedFeatures(VectorSpan<Token> tokens,int context_size,const std::vector<std::vector<int>> & sparse_features,const std::vector<std::vector<float>> & dense_features,const std::function<bool (const std::vector<int> &,const std::vector<float> &,float *)> & feature_vector_fn,int feature_vector_size)42 CachedFeatures(VectorSpan<Token> tokens, int context_size, 43 const std::vector<std::vector<int>>& sparse_features, 44 const std::vector<std::vector<float>>& dense_features, 45 const std::function<bool(const std::vector<int>&, 46 const std::vector<float>&, float*)>& 47 feature_vector_fn, 48 int feature_vector_size) 49 : tokens_(tokens), 50 context_size_(context_size), 51 feature_vector_size_(feature_vector_size), 52 remap_v0_feature_vector_(false), 53 remap_v0_chargram_embedding_size_(-1) { 54 Extract(sparse_features, dense_features, feature_vector_fn); 55 } 56 57 // Gets a VectorSpan with the features for given click position. 58 bool Get(int click_pos, VectorSpan<float>* features, 59 VectorSpan<Token>* output_tokens); 60 61 // Turns on a compatibility mode, which re-maps the extracted features to the 62 // v0 feature format (where the dense features were at the end). 63 // WARNING: Internally v0_feature_storage_ is used as a backing buffer for 64 // VectorSpan<float>, so the output of Extract is valid only until the next 65 // call or destruction of the current CachedFeatures object. 66 // TODO(zilka): Remove when we'll have retrained models. SetV0FeatureMode(int chargram_embedding_size)67 void SetV0FeatureMode(int chargram_embedding_size) { 68 remap_v0_feature_vector_ = true; 69 remap_v0_chargram_embedding_size_ = chargram_embedding_size; 70 v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1)); 71 } 72 73 protected: 74 // Extracts features for all tokens and stores them for later retrieval. 75 void Extract(const std::vector<std::vector<int>>& sparse_features, 76 const std::vector<std::vector<float>>& dense_features, 77 const std::function<bool(const std::vector<int>&, 78 const std::vector<float>&, float*)>& 79 feature_vector_fn); 80 81 // Remaps extracted features to V0 feature format. The mapping is using 82 // the v0_feature_storage_ as the backing storage for the mapped features. 83 // For each token the features consist of: 84 // - chargram embeddings 85 // - dense features 86 // They are concatenated together as [chargram embeddings; dense features] 87 // for each token independently. 88 // The V0 features require that the chargram embeddings for tokens are 89 // concatenated first together, and at the end, the dense features for the 90 // tokens are concatenated to it. 91 void RemapV0FeatureVector(VectorSpan<float>* features); 92 93 private: 94 const VectorSpan<Token> tokens_; 95 const int context_size_; 96 const int feature_vector_size_; 97 bool remap_v0_feature_vector_; 98 int remap_v0_chargram_embedding_size_; 99 100 std::vector<float> features_; 101 std::vector<float> v0_feature_storage_; 102 }; 103 104 } // namespace libtextclassifier 105 106 #endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ 107