1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ 17 18 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 19 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" 20 #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" 21 #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" 22 #include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" 23 24 namespace tensorflow { 25 namespace tensorforest { 26 27 // Abstract base class for classes that can initialize, get, and update leaf 28 // models. 29 class LeafModelOperator { 30 public: 31 // Number of outputs is interpreted differently for classification and 32 // regression. For classification, it's the number of possible classes. 33 // For regression, it's the target dimensions. LeafModelOperator(const TensorForestParams & params)34 explicit LeafModelOperator(const TensorForestParams& params) 35 : params_(params) {} ~LeafModelOperator()36 virtual ~LeafModelOperator() {} 37 38 // Returns the value of the requested output, which should be 39 // in [0, num_outputs_). For classification, it's the class count (weighted 40 // number of instances seen). For regression, it's e.g. the average value. 41 virtual float GetOutputValue(const decision_trees::Leaf& leaf, 42 int32 o) const = 0; 43 44 // Update the given Leaf's model with the given example. 45 virtual void UpdateModel(decision_trees::Leaf* leaf, 46 const InputTarget* target, int example) const = 0; 47 48 // Initialize an empty Leaf model. 49 virtual void InitModel(decision_trees::Leaf* leaf) const = 0; 50 51 virtual void ExportModel(const LeafStat& stat, 52 decision_trees::Leaf* leaf) const = 0; 53 54 protected: 55 const TensorForestParams& params_; 56 }; 57 58 // LeafModelOperator that stores class counts in a dense vector. 59 class DenseClassificationLeafModelOperator : public LeafModelOperator { 60 public: DenseClassificationLeafModelOperator(const TensorForestParams & params)61 explicit DenseClassificationLeafModelOperator( 62 const TensorForestParams& params) 63 : LeafModelOperator(params) {} 64 float GetOutputValue(const decision_trees::Leaf& leaf, 65 int32 o) const override; 66 67 void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, 68 int example) const override; 69 70 void InitModel(decision_trees::Leaf* leaf) const override; 71 72 void ExportModel(const LeafStat& stat, 73 decision_trees::Leaf* leaf) const override; 74 }; 75 76 // LeafModelOperator that stores class counts sparsely in a map. Assumes default 77 // value for yet-unseen classes is 0. 78 class SparseClassificationLeafModelOperator : public LeafModelOperator { 79 public: SparseClassificationLeafModelOperator(const TensorForestParams & params)80 explicit SparseClassificationLeafModelOperator( 81 const TensorForestParams& params) 82 : LeafModelOperator(params) {} 83 float GetOutputValue(const decision_trees::Leaf& leaf, 84 int32 o) const override; 85 86 void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, 87 int example) const override; 88 InitModel(decision_trees::Leaf * leaf)89 void InitModel(decision_trees::Leaf* leaf) const override {} 90 91 void ExportModel(const LeafStat& stat, 92 decision_trees::Leaf* leaf) const override; 93 }; 94 95 class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator { 96 public: SparseOrDenseClassificationLeafModelOperator(const TensorForestParams & params)97 explicit SparseOrDenseClassificationLeafModelOperator( 98 const TensorForestParams& params) 99 : LeafModelOperator(params), 100 dense_(new DenseClassificationLeafModelOperator(params)), 101 sparse_(new SparseClassificationLeafModelOperator(params)) {} 102 float GetOutputValue(const decision_trees::Leaf& leaf, 103 int32 o) const override; 104 105 void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, 106 int example) const override; 107 InitModel(decision_trees::Leaf * leaf)108 void InitModel(decision_trees::Leaf* leaf) const override {} 109 110 void ExportModel(const LeafStat& stat, 111 decision_trees::Leaf* leaf) const override; 112 113 protected: 114 std::unique_ptr<DenseClassificationLeafModelOperator> dense_; 115 std::unique_ptr<SparseClassificationLeafModelOperator> sparse_; 116 }; 117 118 // LeafModelOperator that stores regression leaf models with constant-value 119 // prediction. 120 class RegressionLeafModelOperator : public LeafModelOperator { 121 public: RegressionLeafModelOperator(const TensorForestParams & params)122 explicit RegressionLeafModelOperator(const TensorForestParams& params) 123 : LeafModelOperator(params) {} 124 float GetOutputValue(const decision_trees::Leaf& leaf, 125 int32 o) const override; 126 127 // TODO(gilberth): Quick experimentation suggests it's not even worth 128 // updating model and just using the seeded values. Can add this in 129 // with additional_data, though protobuf::Any is slow. Maybe make it 130 // optional. Maybe make any update optional. UpdateModel(decision_trees::Leaf * leaf,const InputTarget * target,int example)131 void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, 132 int example) const override {} 133 134 void InitModel(decision_trees::Leaf* leaf) const override; 135 136 void ExportModel(const LeafStat& stat, 137 decision_trees::Leaf* leaf) const override; 138 }; 139 140 class LeafModelOperatorFactory { 141 public: 142 static std::unique_ptr<LeafModelOperator> CreateLeafModelOperator( 143 const TensorForestParams& params); 144 }; 145 146 } // namespace tensorforest 147 } // namespace tensorflow 148 149 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ 150