1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
17 
#include <memory>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
23 
24 namespace tensorflow {
25 namespace tensorforest {
26 
27 // Abstract base class for classes that can initialize, get, and update leaf
28 // models.
29 class LeafModelOperator {
30  public:
31   // Number of outputs is interpreted differently for classification and
32   // regression.  For classification, it's the number of possible classes.
33   // For regression, it's the target dimensions.
LeafModelOperator(const TensorForestParams & params)34   explicit LeafModelOperator(const TensorForestParams& params)
35       : params_(params) {}
~LeafModelOperator()36   virtual ~LeafModelOperator() {}
37 
38   // Returns the value of the requested output, which should be
39   // in [0, num_outputs_).  For classification, it's the class count (weighted
40   // number of instances seen).  For regression, it's e.g. the average value.
41   virtual float GetOutputValue(const decision_trees::Leaf& leaf,
42                                int32 o) const = 0;
43 
44   // Update the given Leaf's model with the given example.
45   virtual void UpdateModel(decision_trees::Leaf* leaf,
46                            const InputTarget* target, int example) const = 0;
47 
48   // Initialize an empty Leaf model.
49   virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
50 
51   virtual void ExportModel(const LeafStat& stat,
52                            decision_trees::Leaf* leaf) const = 0;
53 
54  protected:
55   const TensorForestParams& params_;
56 };
57 
58 // LeafModelOperator that stores class counts in a dense vector.
59 class DenseClassificationLeafModelOperator : public LeafModelOperator {
60  public:
DenseClassificationLeafModelOperator(const TensorForestParams & params)61   explicit DenseClassificationLeafModelOperator(
62       const TensorForestParams& params)
63       : LeafModelOperator(params) {}
64   float GetOutputValue(const decision_trees::Leaf& leaf,
65                        int32 o) const override;
66 
67   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
68                    int example) const override;
69 
70   void InitModel(decision_trees::Leaf* leaf) const override;
71 
72   void ExportModel(const LeafStat& stat,
73                    decision_trees::Leaf* leaf) const override;
74 };
75 
76 // LeafModelOperator that stores class counts sparsely in a map. Assumes default
77 // value for yet-unseen classes is 0.
78 class SparseClassificationLeafModelOperator : public LeafModelOperator {
79  public:
SparseClassificationLeafModelOperator(const TensorForestParams & params)80   explicit SparseClassificationLeafModelOperator(
81       const TensorForestParams& params)
82       : LeafModelOperator(params) {}
83   float GetOutputValue(const decision_trees::Leaf& leaf,
84                        int32 o) const override;
85 
86   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
87                    int example) const override;
88 
InitModel(decision_trees::Leaf * leaf)89   void InitModel(decision_trees::Leaf* leaf) const override {}
90 
91   void ExportModel(const LeafStat& stat,
92                    decision_trees::Leaf* leaf) const override;
93 };
94 
95 class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
96  public:
SparseOrDenseClassificationLeafModelOperator(const TensorForestParams & params)97   explicit SparseOrDenseClassificationLeafModelOperator(
98       const TensorForestParams& params)
99       : LeafModelOperator(params),
100         dense_(new DenseClassificationLeafModelOperator(params)),
101         sparse_(new SparseClassificationLeafModelOperator(params)) {}
102   float GetOutputValue(const decision_trees::Leaf& leaf,
103                        int32 o) const override;
104 
105   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
106                    int example) const override;
107 
InitModel(decision_trees::Leaf * leaf)108   void InitModel(decision_trees::Leaf* leaf) const override {}
109 
110   void ExportModel(const LeafStat& stat,
111                    decision_trees::Leaf* leaf) const override;
112 
113  protected:
114   std::unique_ptr<DenseClassificationLeafModelOperator> dense_;
115   std::unique_ptr<SparseClassificationLeafModelOperator> sparse_;
116 };
117 
118 // LeafModelOperator that stores regression leaf models with constant-value
119 // prediction.
120 class RegressionLeafModelOperator : public LeafModelOperator {
121  public:
RegressionLeafModelOperator(const TensorForestParams & params)122   explicit RegressionLeafModelOperator(const TensorForestParams& params)
123       : LeafModelOperator(params) {}
124   float GetOutputValue(const decision_trees::Leaf& leaf,
125                        int32 o) const override;
126 
127   // TODO(gilberth): Quick experimentation suggests it's not even worth
128   // updating model and just using the seeded values.  Can add this in
129   // with additional_data, though protobuf::Any is slow.  Maybe make it
130   // optional.  Maybe make any update optional.
UpdateModel(decision_trees::Leaf * leaf,const InputTarget * target,int example)131   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
132                    int example) const override {}
133 
134   void InitModel(decision_trees::Leaf* leaf) const override;
135 
136   void ExportModel(const LeafStat& stat,
137                    decision_trees::Leaf* leaf) const override;
138 };
139 
140 class LeafModelOperatorFactory {
141  public:
142   static std::unique_ptr<LeafModelOperator> CreateLeafModelOperator(
143       const TensorForestParams& params);
144 };
145 
146 }  // namespace tensorforest
147 }  // namespace tensorflow
148 
149 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
150