/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// This file contains the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive algorithm.

#ifndef LEARNINGFW_MULTICLASS_PA_H_
#define LEARNINGFW_MULTICLASS_PA_H_

#include <cmath>
#include <utility>  // std::pair, used by the sparse-input API below.
#include <vector>

// Small positive tolerance constant; its use is not visible in this header.
// It is kept at global scope (outside namespace learningfw) so that existing
// references continue to resolve. The `f` suffix avoids an implicit
// double -> float conversion.
const float kEpsilon = 1.0e-4f;

namespace learningfw {

// Linear multi-class classifier trained with passive-aggressive updates.
//
// Every operation comes in a dense form (inputs as std::vector<float>) and a
// "Sparse" form (inputs as std::vector<std::pair<int, float> >; presumably
// (dimension index, value) pairs -- confirm against the implementation).
class MulticlassPA {
 public:
  // num_classes:    number of classes of the problem.
  // num_dimensions: number of dimensions of the input vectors.
  // aggressiveness: controls how "aggressive" training should be.
  MulticlassPA(int num_classes,
               int num_dimensions,
               float aggressiveness);
  virtual ~MulticlassPA();

  // Initializes all parameters to 0.0.
  void InitializeParameters();

  // Returns a random class that is different from `target`.
  int PickAClassExcept(int target);

  // Returns a random example. NOTE(review): presumably an index in
  // [0, num_examples) -- confirm against the implementation.
  int PickAnExample(int num_examples);

  // Computes the score of `inputs` for the given parameter vector as the dot
  // product of the two.
  float Score(const std::vector<float>& inputs,
              const std::vector<float>& parameters) const;
  float SparseScore(const std::vector<std::pair<int, float> >& inputs,
                    const std::vector<float>& parameters) const;

  // Returns the square of the L2 norm of `inputs`.
  float L2NormSquare(const std::vector<float>& inputs) const;
  float SparseL2NormSquare(
      const std::vector<std::pair<int, float> >& inputs) const;

  // Checks whether the given example is classified correctly, with a margin,
  // against a randomly chosen other class. If it is not, updates the
  // corresponding parameters using a passive-aggressive step.
  // NOTE(review): the meaning of the returned float is not documented here
  // (possibly the loss of this step) -- confirm in the implementation.
  virtual float TrainOneExample(const std::vector<float>& inputs, int target);
  virtual float SparseTrainOneExample(
      const std::vector<std::pair<int, float> >& inputs, int target);

  // Iteratively trains the model for `num_iterations` on `data`, whose
  // elements are presumably (inputs, target class) pairs as consumed by
  // TrainOneExample. NOTE(review): return value semantics not visible here.
  float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
              int num_iterations);
  float SparseTrain(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >&
          data,
      int num_iterations);

  // Returns the best class for the given input vector.
  virtual int GetClass(const std::vector<float>& inputs);
  virtual int SparseGetClass(
      const std::vector<std::pair<int, float> >& inputs);

  // Computes the test error of the given test set under the current model.
  float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
  float SparseTest(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >&
          data);

  // A few accessors used by subclasses. Functions defined in the class body
  // are implicitly inline.
  float aggressiveness() const {
    return aggressiveness_;
  }

  // Mutable reference to the parameter vectors.
  std::vector<std::vector<float> >& parameters() {
    return parameters_;
  }

  // Mutable pointer to the parameter vectors (never null).
  std::vector<std::vector<float> >* mutable_parameters() {
    return &parameters_;
  }

  int num_classes() const {
    return num_classes_;
  }

  int num_dimensions() const {
    return num_dimensions_;
  }

 private:
  // The current parameter vectors.
  std::vector<std::vector<float> > parameters_;

  // The number of classes of the problem.
  int num_classes_;

  // The number of dimensions of the input vectors.
  int num_dimensions_;

  // Controls how "aggressive" training should be.
  float aggressiveness_;
};

}  // namespace learningfw

#endif  // LEARNINGFW_MULTICLASS_PA_H_