1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h"
16
17 #include <cfloat>
18
19 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
20 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
21 #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
22
23 namespace tensorflow {
24 namespace tensorforest {
25
26 std::unordered_map<int, CollectionCreator*>
27 SplitCollectionOperatorFactory::factories_; // NOLINT
28 REGISTER_SPLIT_COLLECTION(COLLECTION_BASIC, SplitCollectionOperator);
29
30 std::unique_ptr<SplitCollectionOperator>
CreateSplitCollectionOperator(const TensorForestParams & params)31 SplitCollectionOperatorFactory::CreateSplitCollectionOperator(
32 const TensorForestParams& params) {
33 auto it = factories_.find(params.collection_type());
34 if (it == factories_.end()) {
35 LOG(ERROR) << "Unknown split collection operator: "
36 << params.collection_type();
37 return nullptr;
38 } else {
39 return it->second->Create(params);
40 }
41 }
42
CreateGrowStats(int32 node_id,int32 depth) const43 std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
44 int32 node_id, int32 depth) const {
45 switch (params_.stats_type()) {
46 case STATS_DENSE_GINI:
47 return std::unique_ptr<GrowStats>(
48 new DenseClassificationGrowStats(params_, depth));
49
50 case STATS_SPARSE_GINI:
51 return std::unique_ptr<GrowStats>(
52 new SparseClassificationGrowStats(params_, depth));
53
54 case STATS_LEAST_SQUARES_REGRESSION:
55 return std::unique_ptr<GrowStats>(
56 new LeastSquaresRegressionGrowStats(params_, depth));
57
58 case STATS_FIXED_SIZE_SPARSE_GINI:
59 return std::unique_ptr<GrowStats>(
60 new FixedSizeSparseClassificationGrowStats(params_, depth));
61
62 default:
63 LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type();
64 return nullptr;
65 }
66 }
67
ExtractFromProto(const FertileStats & stats_proto)68 void SplitCollectionOperator::ExtractFromProto(
69 const FertileStats& stats_proto) {
70 for (int i = 0; i < stats_proto.node_to_slot_size(); ++i) {
71 const auto& slot = stats_proto.node_to_slot(i);
72 stats_[slot.node_id()] = CreateGrowStats(slot.node_id(), slot.depth());
73 stats_[slot.node_id()]->ExtractFromProto(slot);
74 }
75 }
76
PackToProto(FertileStats * stats_proto) const77 void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
78 for (const auto& pair : stats_) {
79 auto* new_slot = stats_proto->add_node_to_slot();
80 new_slot->set_node_id(pair.first);
81 if (params_.checkpoint_stats()) {
82 pair.second->PackToProto(new_slot);
83 }
84 new_slot->set_depth(pair.second->depth());
85 }
86 }
87
InitializeSlot(int32 node_id,int32 depth)88 void SplitCollectionOperator::InitializeSlot(int32 node_id, int32 depth) {
89 stats_[node_id] = std::unique_ptr<GrowStats>(CreateGrowStats(node_id, depth));
90 stats_[node_id]->Initialize();
91 }
92
AddExample(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,const std::vector<int> & examples,int32 node_id) const93 void SplitCollectionOperator::AddExample(
94 const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
95 const std::vector<int>& examples, int32 node_id) const {
96 auto* slot = stats_.at(node_id).get();
97 for (int example : examples) {
98 slot->AddExample(input_data, target, example);
99 }
100 }
101
IsInitialized(int32 node_id) const102 bool SplitCollectionOperator::IsInitialized(int32 node_id) const {
103 auto it = stats_.find(node_id);
104 if (it == stats_.end()) {
105 LOG(WARNING) << "IsInitialized called with unknown node_id = " << node_id;
106 return false;
107 }
108 return it->second->IsInitialized();
109 }
110
CreateAndInitializeCandidateWithExample(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,int example,int32 node_id) const111 void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
112 const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
113 int example, int32 node_id) const {
114 // Assumes split_initializations_per_input == 1.
115 decision_trees::BinaryNode split;
116 float bias;
117 int type;
118 decision_trees::FeatureId feature_id;
119 input_data->RandomSample(example, &feature_id, &bias, &type);
120
121 if (type == kDataFloat) {
122 decision_trees::InequalityTest* test =
123 split.mutable_inequality_left_child_test();
124 *test->mutable_feature_id() = feature_id;
125 test->mutable_threshold()->set_float_value(bias);
126 test->set_type(params_.inequality_test_type());
127 } else if (type == kDataCategorical) {
128 decision_trees::MatchingValuesTest test;
129 *test.mutable_feature_id() = feature_id;
130 test.add_value()->set_float_value(bias);
131 split.mutable_custom_left_child_test()->PackFrom(test);
132 } else {
133 LOG(ERROR) << "Unknown feature type " << type << ", not sure which "
134 << "node type to use.";
135 }
136 stats_.at(node_id)->AddSplit(split, input_data, target, example);
137 }
138
BestSplit(int32 node_id,SplitCandidate * best,int32 * depth) const139 bool SplitCollectionOperator::BestSplit(int32 node_id, SplitCandidate* best,
140 int32* depth) const {
141 auto* slot = stats_.at(node_id).get();
142 *depth = slot->depth();
143 return slot->BestSplit(best);
144 }
145 } // namespace tensorforest
146 } // namespace tensorflow
147