1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
16 
17 namespace tensorflow {
18 namespace tensorforest {
19 
20 using decision_trees::Leaf;
21 
22 std::unique_ptr<LeafModelOperator>
CreateLeafModelOperator(const TensorForestParams & params)23 LeafModelOperatorFactory::CreateLeafModelOperator(
24     const TensorForestParams& params) {
25   switch (params.leaf_type()) {
26     case MODEL_DENSE_CLASSIFICATION:
27       return std::unique_ptr<LeafModelOperator>(
28           new DenseClassificationLeafModelOperator(params));
29 
30     case MODEL_SPARSE_CLASSIFICATION:
31       return std::unique_ptr<LeafModelOperator>(
32           new SparseClassificationLeafModelOperator(params));
33 
34     case MODEL_SPARSE_OR_DENSE_CLASSIFICATION:
35       return std::unique_ptr<LeafModelOperator>(
36           new SparseOrDenseClassificationLeafModelOperator(params));
37 
38     case MODEL_REGRESSION:
39       return std::unique_ptr<LeafModelOperator>(
40           new RegressionLeafModelOperator(params));
41 
42     default:
43       LOG(ERROR) << "Unknown model operator: " << params.leaf_type();
44       return nullptr;
45   }
46 }
47 
48 // ------------------------ Dense ----------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const49 float DenseClassificationLeafModelOperator::GetOutputValue(
50     const decision_trees::Leaf& leaf, int32 o) const {
51   return leaf.vector().value(o).float_value();
52 }
53 
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const54 void DenseClassificationLeafModelOperator::UpdateModel(
55     Leaf* leaf, const InputTarget* target, int example) const {
56   const int32 int_label = target->GetTargetAsClassIndex(example, 0);
57   QCHECK_LT(int_label, params_.num_outputs())
58       << "Got label greater than indicated number of classes. Is "
59          "params.num_classes set correctly?";
60   QCHECK_GE(int_label, 0);
61   auto* val = leaf->mutable_vector()->mutable_value(int_label);
62 
63   float weight = target->GetTargetWeight(example);
64   val->set_float_value(val->float_value() + weight);
65 }
66 
InitModel(Leaf * leaf) const67 void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
68   for (int i = 0; i < params_.num_outputs(); ++i) {
69     leaf->mutable_vector()->add_value();
70   }
71 }
72 
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const73 void DenseClassificationLeafModelOperator::ExportModel(
74     const LeafStat& stat, decision_trees::Leaf* leaf) const {
75   *leaf->mutable_vector() = stat.classification().dense_counts();
76 }
77 
78 // ------------------------- Sparse -------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const79 float SparseClassificationLeafModelOperator::GetOutputValue(
80     const decision_trees::Leaf& leaf, int32 o) const {
81   const auto it = leaf.sparse_vector().sparse_value().find(o);
82   if (it == leaf.sparse_vector().sparse_value().end()) {
83     return 0;  // default value
84   } else {
85     return it->second.float_value();
86   }
87 }
88 
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const89 void SparseClassificationLeafModelOperator::UpdateModel(
90     Leaf* leaf, const InputTarget* target, int example) const {
91   const int32 int_label = target->GetTargetAsClassIndex(example, 0);
92   QCHECK_LT(int_label, params_.num_outputs())
93       << "Got label greater than indicated number of classes. Is "
94          "params.num_classes set correctly?";
95   QCHECK_GE(int_label, 0);
96   const float weight = target->GetTargetWeight(example);
97 
98   auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
99   auto it = value_map->find(int_label);
100   if (it == value_map->end()) {
101     (*value_map)[int_label].set_float_value(weight);
102   } else {
103     it->second.set_float_value(it->second.float_value() + weight);
104   }
105 }
106 
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const107 void SparseClassificationLeafModelOperator::ExportModel(
108     const LeafStat& stat, decision_trees::Leaf* leaf) const {
109   *leaf->mutable_sparse_vector() = stat.classification().sparse_counts();
110 }
111 
112 // ------------------------- SparseOrDense -------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const113 float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
114     const decision_trees::Leaf& leaf, int32 o) const {
115   if (leaf.has_vector()) {
116     return dense_->GetOutputValue(leaf, o);
117   } else {
118     return sparse_->GetOutputValue(leaf, o);
119   }
120 }
121 
UpdateModel(Leaf * leaf,const InputTarget * target,int example) const122 void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
123     Leaf* leaf, const InputTarget* target, int example) const {
124   if (leaf->has_vector()) {
125     return dense_->UpdateModel(leaf, target, example);
126   } else {
127     return sparse_->UpdateModel(leaf, target, example);
128   }
129 }
130 
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const131 void SparseOrDenseClassificationLeafModelOperator::ExportModel(
132     const LeafStat& stat, decision_trees::Leaf* leaf) const {
133   if (stat.classification().has_dense_counts()) {
134     return dense_->ExportModel(stat, leaf);
135   } else {
136     return sparse_->ExportModel(stat, leaf);
137   }
138 }
139 
140 // ------------------------ Regression ----------------------------- //
GetOutputValue(const decision_trees::Leaf & leaf,int32 o) const141 float RegressionLeafModelOperator::GetOutputValue(
142     const decision_trees::Leaf& leaf, int32 o) const {
143   return leaf.vector().value(o).float_value();
144 }
145 
InitModel(Leaf * leaf) const146 void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
147   for (int i = 0; i < params_.num_outputs(); ++i) {
148     leaf->mutable_vector()->add_value();
149   }
150 }
151 
ExportModel(const LeafStat & stat,decision_trees::Leaf * leaf) const152 void RegressionLeafModelOperator::ExportModel(
153     const LeafStat& stat, decision_trees::Leaf* leaf) const {
154   leaf->clear_vector();
155   for (int i = 0; i < params_.num_outputs(); ++i) {
156     const float new_val =
157         stat.regression().mean_output().value(i).float_value() /
158         stat.weight_sum();
159     leaf->mutable_vector()->add_value()->set_float_value(new_val);
160   }
161 }
162 
163 }  // namespace tensorforest
164 }  // namespace tensorflow
165