1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Generic feature extractor for extracting features from objects. The feature
18 // extractor can be used for extracting features from any object. The feature
19 // extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
21 //
22 // A feature extractor consists of a hierarchy of feature functions. Each
23 // feature function extracts one or more feature type and value pairs from the
24 // object.
25 //
26 // The feature extractor has a modular design where new feature functions can be
27 // registered as components. The feature extractor is initialized from a
28 // descriptor represented by a protocol buffer. The feature extractor can also
29 // be initialized from a text-based source specification of the feature
30 // extractor. Feature specification parsers can be added as components. By
31 // default the feature extractor can be read from an ASCII protocol buffer or in
32 // a simple feature modeling language (fml).
33 
// A feature function is invoked with a focus. Nested feature functions can be
// invoked with another focus determined by the parent feature function.
36 
37 #ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
38 #define LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
39 
40 #include <stddef.h>
41 
42 #include <string>
43 #include <vector>
44 
45 #include "common/feature-descriptors.h"
46 #include "common/feature-types.h"
47 #include "common/fml-parser.h"
48 #include "common/registry.h"
49 #include "common/task-context.h"
50 #include "common/workspace.h"
51 #include "util/base/integral_types.h"
52 #include "util/base/logging.h"
53 #include "util/base/macros.h"
54 #include "util/gtl/stl_util.h"
55 
56 namespace libtextclassifier {
57 namespace nlp_core {
58 
59 typedef int64 Predicate;
60 typedef Predicate FeatureValue;
61 
62 // A union used to represent discrete and continuous feature values.
63 union FloatFeatureValue {
64  public:
FloatFeatureValue(FeatureValue v)65   explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
FloatFeatureValue(uint32 i,float w)66   FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
67   FeatureValue discrete_value;
68   struct {
69     uint32 id;
70     float weight;
71   };
72 };
73 
74 // A feature vector contains feature type and value pairs.
75 class FeatureVector {
76  public:
FeatureVector()77   FeatureVector() {}
78 
79   // Adds feature type and value pair to feature vector.
add(FeatureType * type,FeatureValue value)80   void add(FeatureType *type, FeatureValue value) {
81     features_.emplace_back(type, value);
82   }
83 
84   // Removes all elements from the feature vector.
clear()85   void clear() { features_.clear(); }
86 
87   // Returns the number of elements in the feature vector.
size()88   int size() const { return features_.size(); }
89 
90   // Reserves space in the underlying feature vector.
reserve(int n)91   void reserve(int n) { features_.reserve(n); }
92 
93   // Returns feature type for an element in the feature vector.
type(int index)94   FeatureType *type(int index) const { return features_[index].type; }
95 
96   // Returns feature value for an element in the feature vector.
value(int index)97   FeatureValue value(int index) const { return features_[index].value; }
98 
99  private:
100   // Structure for holding feature type and value pairs.
101   struct Element {
ElementElement102     Element() : type(nullptr), value(-1) {}
ElementElement103     Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
104 
105     FeatureType *type;
106     FeatureValue value;
107   };
108 
109   // Array for storing feature vector elements.
110   std::vector<Element> features_;
111 
112   TC_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
113 };
114 
// The generic feature extractor is the type-independent part of a feature
// extractor. It holds the descriptor for the feature extractor and the
// collection of feature types used in the feature extractor. The feature
// types are not available until FeatureExtractor<>::Init() has been called.
//
// Subclasses (FeatureExtractor<>) supply the pure-virtual hooks
// InitializeFeatureFunctions() and GetFeatureTypes().
class GenericFeatureExtractor {
 public:
  GenericFeatureExtractor();
  virtual ~GenericFeatureExtractor();

  // Initializes the feature extractor from an FML string specification. For
  // the FML specification grammar, see fml-parser.h.
  //
  // Returns true on success, false on syntax error.
  bool Parse(const std::string &source);

  // Returns the feature extractor descriptor (read-only).
  const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
  // Returns a mutable pointer to the descriptor stored inside this object.
  FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }

  // Returns the number of feature types in the feature extractor. Invalid
  // before Init() has been called.
  int feature_types() const { return feature_types_.size(); }

  // Returns a feature type used in the extractor. Invalid before Init() has
  // been called. `index` is not bounds-checked.
  const FeatureType *feature_type(int index) const {
    return feature_types_[index];
  }

  // Returns the feature domain size of this feature extractor.
  // NOTE: The way that domain size is calculated is, for some, unintuitive. It
  // is the largest domain size of any feature type.
  FeatureValue GetDomainSize() const;

 protected:
  // Initializes the feature types used by the extractor. Called from
  // FeatureExtractor<>::Init() after all top-level functions are initialized.
  //
  // Returns true on success, false on error.
  bool InitializeFeatureTypes();

 private:
  // Creates the top-level feature functions from the descriptor. Implemented
  // by FeatureExtractor<>. Returns false if any function cannot be created.
  virtual bool InitializeFeatureFunctions() = 0;

  // Appends all feature types used by the extractor to `types`.
  virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;

  // Descriptor for the feature extractor. This is a protocol buffer that
  // contains all the information about the feature extractor. The feature
  // functions are initialized from the information in the descriptor.
  FeatureExtractorDescriptor descriptor_;

  // All feature types used by the feature extractor. The collection of all the
  // feature types describes the feature space of the feature set produced by
  // the feature extractor. Not owned.
  std::vector<FeatureType *> feature_types_;

  TC_DISALLOW_COPY_AND_ASSIGN(GenericFeatureExtractor);
};
176 
177 // The generic feature function is the type-independent part of a feature
178 // function. Each feature function is associated with the descriptor that it is
179 // instantiated from.  The feature types associated with this feature function
180 // will be established by the time FeatureExtractor<>::Init() completes.
181 class GenericFeatureFunction {
182  public:
183   // A feature value that represents the absence of a value.
184   static constexpr FeatureValue kNone = -1;
185 
186   GenericFeatureFunction();
187   virtual ~GenericFeatureFunction();
188 
189   // Sets up the feature function. NB: FeatureTypes of nested functions are not
190   // guaranteed to be available until Init().
191   //
192   // Returns true on success, false on error.
Setup(TaskContext * context)193   virtual bool Setup(TaskContext *context) { return true; }
194 
195   // Initializes the feature function. NB: The FeatureType of this function must
196   // be established when this method completes.
197   //
198   // Returns true on success, false on error.
Init(TaskContext * context)199   virtual bool Init(TaskContext *context) { return true; }
200 
201   // Requests workspaces from a registry to obtain indices into a WorkspaceSet
202   // for any Workspace objects used by this feature function. NB: This will be
203   // called after Init(), so it can depend on resources and arguments.
RequestWorkspaces(WorkspaceRegistry * registry)204   virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
205 
206   // Appends the feature types produced by the feature function to types.  The
207   // default implementation appends feature_type(), if non-null.  Invalid
208   // before Init() has been called.
209   virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
210 
211   // Returns the feature type for feature produced by this feature function. If
212   // the feature function produces features of different types this returns
213   // null.  Invalid before Init() has been called.
214   virtual FeatureType *GetFeatureType() const;
215 
216   // Returns the name of the registry used for creating the feature function.
217   // This can be used for checking if two feature functions are of the same
218   // kind.
219   virtual const char *RegistryName() const = 0;
220 
221   // Returns the value of a named parameter from the feature function
222   // descriptor.  Returns empty string ("") if parameter is not found.
223   std::string GetParameter(const std::string &name) const;
224 
225   // Returns the int value of a named parameter from the feature function
226   // descriptor.  Returns default_value if the parameter is not found or if its
227   // value can't be parsed as an int.
228   int GetIntParameter(const std::string &name, int default_value) const;
229 
230   // Returns the bool value of a named parameter from the feature function
231   // descriptor.  Returns default_value if the parameter is not found or if its
232   // value is not "true" or "false".
233   bool GetBoolParameter(const std::string &name, bool default_value) const;
234 
235   // Returns the FML function description for the feature function, i.e. the
236   // name and parameters without the nested features.
FunctionName()237   std::string FunctionName() const {
238     std::string output;
239     ToFMLFunction(*descriptor_, &output);
240     return output;
241   }
242 
243   // Returns the prefix for nested feature functions. This is the prefix of this
244   // feature function concatenated with the feature function name.
SubPrefix()245   std::string SubPrefix() const {
246     return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
247   }
248 
249   // Returns/sets the feature extractor this function belongs to.
extractor()250   GenericFeatureExtractor *extractor() const { return extractor_; }
set_extractor(GenericFeatureExtractor * extractor)251   void set_extractor(GenericFeatureExtractor *extractor) {
252     extractor_ = extractor;
253   }
254 
255   // Returns/sets the feature function descriptor.
descriptor()256   FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
set_descriptor(FeatureFunctionDescriptor * descriptor)257   void set_descriptor(FeatureFunctionDescriptor *descriptor) {
258     descriptor_ = descriptor;
259   }
260 
261   // Returns a descriptive name for the feature function. The name is taken from
262   // the descriptor for the feature function. If the name is empty or the
263   // feature function is a variable the name is the FML representation of the
264   // feature, including the prefix.
265   std::string name() const;
266 
267   // Returns the argument from the feature function descriptor. It defaults to
268   // 0 if the argument has not been specified.
argument()269   int argument() const {
270     return descriptor_->has_argument() ? descriptor_->argument() : 0;
271   }
272 
273   // Returns/sets/clears function name prefix.
prefix()274   const std::string &prefix() const { return prefix_; }
set_prefix(const std::string & prefix)275   void set_prefix(const std::string &prefix) { prefix_ = prefix; }
276 
277  protected:
278   // Returns the feature type for single-type feature functions.
feature_type()279   FeatureType *feature_type() const { return feature_type_; }
280 
281   // Sets the feature type for single-type feature functions.  This takes
282   // ownership of feature_type.  Can only be called once with a non-null
283   // pointer.
set_feature_type(FeatureType * feature_type)284   void set_feature_type(FeatureType *feature_type) {
285     TC_DCHECK_NE(feature_type, nullptr);
286     feature_type_ = feature_type;
287   }
288 
289  private:
290   // Feature extractor this feature function belongs to.  Not owned.
291   GenericFeatureExtractor *extractor_ = nullptr;
292 
293   // Descriptor for feature function.  Not owned.
294   FeatureFunctionDescriptor *descriptor_ = nullptr;
295 
296   // Feature type for features produced by this feature function. If the
297   // feature function produces features of multiple feature types this is null
298   // and the feature function must return it's feature types in
299   // GetFeatureTypes().  Owned.
300   FeatureType *feature_type_ = nullptr;
301 
302   // Prefix used for sub-feature types of this function.
303   std::string prefix_;
304 };
305 
306 // Feature function that can extract features from an object.  Templated on
307 // two type arguments:
308 //
309 // OBJ:  The "object" from which features are extracted; e.g., a sentence.  This
310 //       should be a plain type, rather than a reference or pointer.
311 //
312 // ARGS: A set of 0 or more types that are used to "index" into some part of the
313 //       object that should be extracted, e.g. an int token index for a sentence
314 //       object.  This should not be a reference type.
315 template <class OBJ, class... ARGS>
316 class FeatureFunction
317     : public GenericFeatureFunction,
318       public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
319  public:
320   using Self = FeatureFunction<OBJ, ARGS...>;
321 
322   // Preprocesses the object.  This will be called prior to calling Evaluate()
323   // or Compute() on that object.
Preprocess(WorkspaceSet * workspaces,OBJ * object)324   virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {}
325 
326   // Appends features computed from the object and focus to the result.  The
327   // default implementation delegates to Compute(), adding a single value if
328   // available.  Multi-valued feature functions must override this method.
Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)329   virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
330                         ARGS... args, FeatureVector *result) const {
331     FeatureValue value = Compute(workspaces, object, args..., result);
332     if (value != kNone) result->add(feature_type(), value);
333   }
334 
335   // Returns a feature value computed from the object and focus, or kNone if no
336   // value is computed.  Single-valued feature functions only need to override
337   // this method.
Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * fv)338   virtual FeatureValue Compute(const WorkspaceSet &workspaces,
339                                const OBJ &object, ARGS... args,
340                                const FeatureVector *fv) const {
341     return kNone;
342   }
343 
344   // Instantiates a new feature function in a feature extractor from a feature
345   // descriptor.
Instantiate(GenericFeatureExtractor * extractor,FeatureFunctionDescriptor * fd,const std::string & prefix)346   static Self *Instantiate(GenericFeatureExtractor *extractor,
347                            FeatureFunctionDescriptor *fd,
348                            const std::string &prefix) {
349     Self *f = Self::Create(fd->type());
350     if (f != nullptr) {
351       f->set_extractor(extractor);
352       f->set_descriptor(fd);
353       f->set_prefix(prefix);
354     }
355     return f;
356   }
357 
358   // Returns the name of the registry for the feature function.
RegistryName()359   const char *RegistryName() const override { return Self::registry()->name(); }
360 
361  private:
362   // Special feature function class for resolving variable references. The type
363   // of the feature function is used for resolving the variable reference. When
364   // evaluated it will either get the feature value(s) from the variable portion
365   // of the feature vector, if present, or otherwise it will call the referenced
366   // feature extractor function directly to extract the feature(s).
367   class Reference;
368 };
369 
370 // Base class for features with nested feature functions. The nested functions
371 // are of type NES, which may be different from the type of the parent function.
372 // NB: NestedFeatureFunction will ensure that all initialization of nested
373 // functions takes place during Setup() and Init() -- after the nested features
374 // are initialized, the parent feature is initialized via SetupNested() and
375 // InitNested(). Alternatively, a derived classes that overrides Setup() and
376 // Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
377 //
378 // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
379 // Compute, since the nested functions may be of a different type.
380 template <class NES, class OBJ, class... ARGS>
381 class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
382  public:
383   using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
384 
385   // Clean up nested functions.
~NestedFeatureFunction()386   ~NestedFeatureFunction() override {
387     // Fully qualified class name, to avoid an ambiguity error when building for
388     // Android.
389     ::libtextclassifier::STLDeleteElements(&nested_);
390   }
391 
392   // By default, just appends the nested feature types.
GetFeatureTypes(std::vector<FeatureType * > * types)393   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
394     // It's odd if a NestedFeatureFunction does not have anything nested inside
395     // it, so we crash in debug mode.  Still, nothing should crash in prod mode.
396     TC_DCHECK(!this->nested().empty())
397         << "Nested features require nested features to be defined.";
398     for (auto *function : nested_) function->GetFeatureTypes(types);
399   }
400 
401   // Sets up the nested features.
Setup(TaskContext * context)402   bool Setup(TaskContext *context) override {
403     bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
404                                 this->SubPrefix());
405     if (!success) {
406       return false;
407     }
408     for (auto *function : nested_) {
409       if (!function->Setup(context)) return false;
410     }
411     if (!SetupNested(context)) {
412       return false;
413     }
414     return true;
415   }
416 
417   // Sets up this NestedFeatureFunction specifically.
SetupNested(TaskContext * context)418   virtual bool SetupNested(TaskContext *context) { return true; }
419 
420   // Initializes the nested features.
Init(TaskContext * context)421   bool Init(TaskContext *context) override {
422     for (auto *function : nested_) {
423       if (!function->Init(context)) return false;
424     }
425     if (!InitNested(context)) return false;
426     return true;
427   }
428 
429   // Initializes this NestedFeatureFunction specifically.
InitNested(TaskContext * context)430   virtual bool InitNested(TaskContext *context) { return true; }
431 
432   // Gets all the workspaces needed for the nested functions.
RequestWorkspaces(WorkspaceRegistry * registry)433   void RequestWorkspaces(WorkspaceRegistry *registry) override {
434     for (auto *function : nested_) function->RequestWorkspaces(registry);
435   }
436 
437   // Returns the list of nested feature functions.
nested()438   const std::vector<NES *> &nested() const { return nested_; }
439 
440   // Instantiates nested feature functions for a feature function. Creates and
441   // initializes one feature function for each sub-descriptor in the feature
442   // descriptor.
CreateNested(GenericFeatureExtractor * extractor,FeatureFunctionDescriptor * fd,std::vector<NES * > * functions,const std::string & prefix)443   static bool CreateNested(GenericFeatureExtractor *extractor,
444                            FeatureFunctionDescriptor *fd,
445                            std::vector<NES *> *functions,
446                            const std::string &prefix) {
447     for (int i = 0; i < fd->feature_size(); ++i) {
448       FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
449       NES *f = NES::Instantiate(extractor, sub, prefix);
450       if (f == nullptr) {
451         return false;
452       }
453       functions->push_back(f);
454     }
455     return true;
456   }
457 
458  protected:
459   // The nested feature functions, if any, in order of declaration in the
460   // feature descriptor.  Owned.
461   std::vector<NES *> nested_;
462 };
463 
464 // Base class for a nested feature function that takes nested features with the
465 // same signature as these features, i.e. a meta feature. For this class, we can
466 // provide preprocessing of the nested features.
467 template <class OBJ, class... ARGS>
468 class MetaFeatureFunction
469     : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
470                                    ARGS...> {
471  public:
472   // Preprocesses using the nested features.
Preprocess(WorkspaceSet * workspaces,OBJ * object)473   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
474     for (auto *function : this->nested_) {
475       function->Preprocess(workspaces, object);
476     }
477   }
478 };
479 
480 // Template for a special type of locator: The locator of type
481 // FeatureFunction<OBJ, ARGS...> calls nested functions of type
482 // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
483 // responsible for translating by providing the following:
484 //
485 // // Gets the new additional focus.
486 // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
487 //
488 // This is useful to e.g. add a token focus to a parser state based on some
489 // desired property of that state.
490 template <class DER, class OBJ, class IDX, class... ARGS>
491 class FeatureAddFocusLocator
492     : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
493                                    ARGS...> {
494  public:
Preprocess(WorkspaceSet * workspaces,OBJ * object)495   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
496     for (auto *function : this->nested_) {
497       function->Preprocess(workspaces, object);
498     }
499   }
500 
Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)501   void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
502                 FeatureVector *result) const override {
503     IDX focus =
504         static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
505     for (auto *function : this->nested()) {
506       function->Evaluate(workspaces, object, focus, args..., result);
507     }
508   }
509 
510   // Returns the first nested feature's computed value.
Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * result)511   FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
512                        ARGS... args,
513                        const FeatureVector *result) const override {
514     IDX focus =
515         static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
516     return this->nested()[0]->Compute(workspaces, object, focus, args...,
517                                       result);
518   }
519 };
520 
521 // CRTP feature locator class. This is a meta feature that modifies ARGS and
522 // then calls the nested feature functions with the modified ARGS. Note that in
523 // order for this template to work correctly, all of ARGS must be types for
524 // which the reference operator & can be interpreted as a pointer to the
525 // argument. The derived class DER must implement the UpdateFocus method which
526 // takes pointers to the ARGS arguments:
527 //
528 // // Updates the current arguments.
529 // void UpdateArgs(const OBJ &object, ARGS *...args) const;
530 template <class DER, class OBJ, class... ARGS>
531 class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
532  public:
533   // Feature locators have an additional check that there is no intrinsic type,
534   // but only in debug mode: having an intrinsic type here is odd, but not
535   // enough to motive a crash in prod.
GetFeatureTypes(std::vector<FeatureType * > * types)536   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
537     TC_DCHECK_EQ(this->feature_type(), nullptr)
538         << "FeatureLocators should not have an intrinsic type.";
539     MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
540   }
541 
542   // Evaluates the locator.
Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)543   void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
544                 FeatureVector *result) const override {
545     static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
546     for (auto *function : this->nested()) {
547       function->Evaluate(workspaces, object, args..., result);
548     }
549   }
550 
551   // Returns the first nested feature's computed value.
Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * result)552   FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
553                        ARGS... args,
554                        const FeatureVector *result) const override {
555     static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
556     return this->nested()[0]->Compute(workspaces, object, args..., result);
557   }
558 };
559 
560 // Feature extractor for extracting features from objects of a certain class.
561 // Template type parameters are as defined for FeatureFunction.
562 template <class OBJ, class... ARGS>
563 class FeatureExtractor : public GenericFeatureExtractor {
564  public:
565   // Feature function type for top-level functions in the feature extractor.
566   typedef FeatureFunction<OBJ, ARGS...> Function;
567   typedef FeatureExtractor<OBJ, ARGS...> Self;
568 
569   // Feature locator type for the feature extractor.
570   template <class DER>
571   using Locator = FeatureLocator<DER, OBJ, ARGS...>;
572 
573   // Initializes feature extractor.
FeatureExtractor()574   FeatureExtractor() {}
575 
~FeatureExtractor()576   ~FeatureExtractor() override {
577     // Fully qualified class name, to avoid an ambiguity error when building for
578     // Android.
579     ::libtextclassifier::STLDeleteElements(&functions_);
580   }
581 
582   // Sets up the feature extractor. Note that only top-level functions exist
583   // until Setup() is called. This does not take ownership over the context,
584   // which must outlive this.
Setup(TaskContext * context)585   bool Setup(TaskContext *context) {
586     for (Function *function : functions_) {
587       if (!function->Setup(context)) return false;
588     }
589     return true;
590   }
591 
592   // Initializes the feature extractor.  Must be called after Setup().  This
593   // does not take ownership over the context, which must outlive this.
Init(TaskContext * context)594   bool Init(TaskContext *context) {
595     for (Function *function : functions_) {
596       if (!function->Init(context)) return false;
597     }
598     if (!this->InitializeFeatureTypes()) {
599       return false;
600     }
601     return true;
602   }
603 
604   // Requests workspaces from the registry. Must be called after Init(), and
605   // before Preprocess(). Does not take ownership over registry. This should be
606   // the same registry used to initialize the WorkspaceSet used in Preprocess()
607   // and ExtractFeatures(). NB: This is a different ordering from that used in
608   // SentenceFeatureRepresentation style feature computation.
RequestWorkspaces(WorkspaceRegistry * registry)609   void RequestWorkspaces(WorkspaceRegistry *registry) {
610     for (auto *function : functions_) function->RequestWorkspaces(registry);
611   }
612 
613   // Preprocesses the object using feature functions for the phase.  Must be
614   // called before any calls to ExtractFeatures() on that object and phase.
Preprocess(WorkspaceSet * workspaces,OBJ * object)615   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {
616     for (Function *function : functions_) {
617       function->Preprocess(workspaces, object);
618     }
619   }
620 
621   // Extracts features from an object with a focus. This invokes all the
622   // top-level feature functions in the feature extractor. Only feature
623   // functions belonging to the specified phase are invoked.
ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)624   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
625                        ARGS... args, FeatureVector *result) const {
626     result->reserve(this->feature_types());
627 
628     // Extract features.
629     for (int i = 0; i < functions_.size(); ++i) {
630       functions_[i]->Evaluate(workspaces, object, args..., result);
631     }
632   }
633 
634  private:
635   // Creates and initializes all feature functions in the feature extractor.
InitializeFeatureFunctions()636   bool InitializeFeatureFunctions() override {
637     // Create all top-level feature functions.
638     for (int i = 0; i < descriptor().feature_size(); ++i) {
639       FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i);
640       Function *function = Function::Instantiate(this, fd, "");
641       if (function == nullptr) return false;
642       functions_.push_back(function);
643     }
644     return true;
645   }
646 
647   // Collect all feature types used in the feature extractor.
GetFeatureTypes(std::vector<FeatureType * > * types)648   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
649     for (Function *function : functions_) {
650       function->GetFeatureTypes(types);
651     }
652   }
653 
654   // Top-level feature functions (and variables) in the feature extractor.
655   // Owned.  INVARIANT: contains only non-null pointers.
656   std::vector<Function *> functions_;
657 };
658 
// Feature-function flavored alias that expands directly to
// REGISTER_CLASS_COMPONENT (from common/registry.h) with the same arguments.
#define REGISTER_FEATURE_FUNCTION(base_class, class_name, component_name) \
  REGISTER_CLASS_COMPONENT(base_class, class_name, component_name)
661 
662 }  // namespace nlp_core
663 }  // namespace libtextclassifier
664 
665 #endif  // LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
666