1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
16
17 namespace tensorflow {
18 namespace tensorforest {
19
20 using decision_trees::Leaf;
21
22 std::unique_ptr<LeafModelOperator>
CreateLeafModelOperator(const TensorForestParams & params)23 LeafModelOperatorFactory::CreateLeafModelOperator(
24 const TensorForestParams& params) {
25 switch (params.leaf_type()) {
26 case MODEL_DENSE_CLASSIFICATION:
27 return std::unique_ptr<LeafModelOperator>(
28 new DenseClassificationLeafModelOperator(params));
29
30 case MODEL_SPARSE_CLASSIFICATION:
31 return std::unique_ptr<LeafModelOperator>(
32 new SparseClassificationLeafModelOperator(params));
33
34 case MODEL_SPARSE_OR_DENSE_CLASSIFICATION:
35 return std::unique_ptr<LeafModelOperator>(
36 new SparseOrDenseClassificationLeafModelOperator(params));
37
38 case MODEL_REGRESSION:
39 return std::unique_ptr<LeafModelOperator>(
40 new RegressionLeafModelOperator(params));
41
42 default:
43 LOG(ERROR) << "Unknown model operator: " << params.leaf_type();
44 return nullptr;
45 }
46 }
47
48 // ------------------------ Dense ----------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const49 float DenseClassificationLeafModelOperator::GetOutputValue(
50 const decision_trees::Leaf& leaf, int32 o) const {
51 return leaf.vector().value(o).float_value();
52 }
53
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const54 void DenseClassificationLeafModelOperator::UpdateModel(
55 Leaf* leaf, const InputTarget* target, int example) const {
56 const int32 int_label = target->GetTargetAsClassIndex(example, 0);
57 QCHECK_LT(int_label, params_.num_outputs())
58 << "Got label greater than indicated number of classes. Is "
59 "params.num_classes set correctly?";
60 QCHECK_GE(int_label, 0);
61 auto* val = leaf->mutable_vector()->mutable_value(int_label);
62
63 float weight = target->GetTargetWeight(example);
64 val->set_float_value(val->float_value() + weight);
65 }
66
InitModel(Leaf * leaf) const67 void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
68 for (int i = 0; i < params_.num_outputs(); ++i) {
69 leaf->mutable_vector()->add_value();
70 }
71 }
72
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const73 void DenseClassificationLeafModelOperator::ExportModel(
74 const LeafStat& stat, decision_trees::Leaf* leaf) const {
75 *leaf->mutable_vector() = stat.classification().dense_counts();
76 }
77
78 // ------------------------- Sparse -------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const79 float SparseClassificationLeafModelOperator::GetOutputValue(
80 const decision_trees::Leaf& leaf, int32 o) const {
81 const auto it = leaf.sparse_vector().sparse_value().find(o);
82 if (it == leaf.sparse_vector().sparse_value().end()) {
83 return 0; // default value
84 } else {
85 return it->second.float_value();
86 }
87 }
88
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const89 void SparseClassificationLeafModelOperator::UpdateModel(
90 Leaf* leaf, const InputTarget* target, int example) const {
91 const int32 int_label = target->GetTargetAsClassIndex(example, 0);
92 QCHECK_LT(int_label, params_.num_outputs())
93 << "Got label greater than indicated number of classes. Is "
94 "params.num_classes set correctly?";
95 QCHECK_GE(int_label, 0);
96 const float weight = target->GetTargetWeight(example);
97
98 auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
99 auto it = value_map->find(int_label);
100 if (it == value_map->end()) {
101 (*value_map)[int_label].set_float_value(weight);
102 } else {
103 it->second.set_float_value(it->second.float_value() + weight);
104 }
105 }
106
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const107 void SparseClassificationLeafModelOperator::ExportModel(
108 const LeafStat& stat, decision_trees::Leaf* leaf) const {
109 *leaf->mutable_sparse_vector() = stat.classification().sparse_counts();
110 }
111
112 // ------------------------- SparseOrDense -------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const113 float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
114 const decision_trees::Leaf& leaf, int32 o) const {
115 if (leaf.has_vector()) {
116 return dense_->GetOutputValue(leaf, o);
117 } else {
118 return sparse_->GetOutputValue(leaf, o);
119 }
120 }
121
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const122 void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
123 Leaf* leaf, const InputTarget* target, int example) const {
124 if (leaf->has_vector()) {
125 return dense_->UpdateModel(leaf, target, example);
126 } else {
127 return sparse_->UpdateModel(leaf, target, example);
128 }
129 }
130
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const131 void SparseOrDenseClassificationLeafModelOperator::ExportModel(
132 const LeafStat& stat, decision_trees::Leaf* leaf) const {
133 if (stat.classification().has_dense_counts()) {
134 return dense_->ExportModel(stat, leaf);
135 } else {
136 return sparse_->ExportModel(stat, leaf);
137 }
138 }
139
140 // ------------------------ Regression ----------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const141 float RegressionLeafModelOperator::GetOutputValue(
142 const decision_trees::Leaf& leaf, int32 o) const {
143 return leaf.vector().value(o).float_value();
144 }
145
InitModel(Leaf * leaf) const146 void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
147 for (int i = 0; i < params_.num_outputs(); ++i) {
148 leaf->mutable_vector()->add_value();
149 }
150 }
151
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const152 void RegressionLeafModelOperator::ExportModel(
153 const LeafStat& stat, decision_trees::Leaf* leaf) const {
154 leaf->clear_vector();
155 for (int i = 0; i < params_.num_outputs(); ++i) {
156 const float new_val =
157 stat.regression().mean_output().value(i).float_value() /
158 stat.weight_sum();
159 leaf->mutable_vector()->add_value()->set_float_value(new_val);
160 }
161 }
162
163 } // namespace tensorforest
164 } // namespace tensorflow
165