1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h"
16 
17 #include <cfloat>
18 
19 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
20 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
21 #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
22 
23 namespace tensorflow {
24 namespace tensorforest {
25 
// Storage for the factory registry. Keys are compared against
// params.collection_type() in CreateSplitCollectionOperator(); values are the
// creators used to build the matching SplitCollectionOperator.
std::unordered_map<int, CollectionCreator*>
    SplitCollectionOperatorFactory::factories_;  // NOLINT
// Registers the basic SplitCollectionOperator for COLLECTION_BASIC.
// NOTE(review): presumably this macro inserts a creator into factories_ during
// static initialization — confirm against the REGISTER_SPLIT_COLLECTION
// definition in split_collection_operators.h.
REGISTER_SPLIT_COLLECTION(COLLECTION_BASIC, SplitCollectionOperator);
29 
30 std::unique_ptr<SplitCollectionOperator>
CreateSplitCollectionOperator(const TensorForestParams & params)31 SplitCollectionOperatorFactory::CreateSplitCollectionOperator(
32     const TensorForestParams& params) {
33   auto it = factories_.find(params.collection_type());
34   if (it == factories_.end()) {
35     LOG(ERROR) << "Unknown split collection operator: "
36                << params.collection_type();
37     return nullptr;
38   } else {
39     return it->second->Create(params);
40   }
41 }
42 
CreateGrowStats(int32 node_id,int32 depth) const43 std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
44     int32 node_id, int32 depth) const {
45   switch (params_.stats_type()) {
46     case STATS_DENSE_GINI:
47       return std::unique_ptr<GrowStats>(
48           new DenseClassificationGrowStats(params_, depth));
49 
50     case STATS_SPARSE_GINI:
51       return std::unique_ptr<GrowStats>(
52           new SparseClassificationGrowStats(params_, depth));
53 
54     case STATS_LEAST_SQUARES_REGRESSION:
55       return std::unique_ptr<GrowStats>(
56           new LeastSquaresRegressionGrowStats(params_, depth));
57 
58     case STATS_FIXED_SIZE_SPARSE_GINI:
59       return std::unique_ptr<GrowStats>(
60           new FixedSizeSparseClassificationGrowStats(params_, depth));
61 
62     default:
63       LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type();
64       return nullptr;
65   }
66 }
67 
ExtractFromProto(const FertileStats & stats_proto)68 void SplitCollectionOperator::ExtractFromProto(
69     const FertileStats& stats_proto) {
70   for (int i = 0; i < stats_proto.node_to_slot_size(); ++i) {
71     const auto& slot = stats_proto.node_to_slot(i);
72     stats_[slot.node_id()] = CreateGrowStats(slot.node_id(), slot.depth());
73     stats_[slot.node_id()]->ExtractFromProto(slot);
74   }
75 }
76 
PackToProto(FertileStats * stats_proto) const77 void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
78   for (const auto& pair : stats_) {
79     auto* new_slot = stats_proto->add_node_to_slot();
80     new_slot->set_node_id(pair.first);
81     if (params_.checkpoint_stats()) {
82       pair.second->PackToProto(new_slot);
83     }
84     new_slot->set_depth(pair.second->depth());
85   }
86 }
87 
InitializeSlot(int32 node_id,int32 depth)88 void SplitCollectionOperator::InitializeSlot(int32 node_id, int32 depth) {
89   stats_[node_id] = std::unique_ptr<GrowStats>(CreateGrowStats(node_id, depth));
90   stats_[node_id]->Initialize();
91 }
92 
AddExample(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,const std::vector<int> & examples,int32 node_id) const93 void SplitCollectionOperator::AddExample(
94     const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
95     const std::vector<int>& examples, int32 node_id) const {
96   auto* slot = stats_.at(node_id).get();
97   for (int example : examples) {
98     slot->AddExample(input_data, target, example);
99   }
100 }
101 
IsInitialized(int32 node_id) const102 bool SplitCollectionOperator::IsInitialized(int32 node_id) const {
103   auto it = stats_.find(node_id);
104   if (it == stats_.end()) {
105     LOG(WARNING) << "IsInitialized called with unknown node_id = " << node_id;
106     return false;
107   }
108   return it->second->IsInitialized();
109 }
110 
CreateAndInitializeCandidateWithExample(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,int example,int32 node_id) const111 void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
112     const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
113     int example, int32 node_id) const {
114   // Assumes split_initializations_per_input == 1.
115   decision_trees::BinaryNode split;
116   float bias;
117   int type;
118   decision_trees::FeatureId feature_id;
119   input_data->RandomSample(example, &feature_id, &bias, &type);
120 
121   if (type == kDataFloat) {
122     decision_trees::InequalityTest* test =
123         split.mutable_inequality_left_child_test();
124     *test->mutable_feature_id() = feature_id;
125     test->mutable_threshold()->set_float_value(bias);
126     test->set_type(params_.inequality_test_type());
127   } else if (type == kDataCategorical) {
128     decision_trees::MatchingValuesTest test;
129     *test.mutable_feature_id() = feature_id;
130     test.add_value()->set_float_value(bias);
131     split.mutable_custom_left_child_test()->PackFrom(test);
132   } else {
133     LOG(ERROR) << "Unknown feature type " << type << ", not sure which "
134                << "node type to use.";
135   }
136   stats_.at(node_id)->AddSplit(split, input_data, target, example);
137 }
138 
BestSplit(int32 node_id,SplitCandidate * best,int32 * depth) const139 bool SplitCollectionOperator::BestSplit(int32 node_id, SplitCandidate* best,
140                                         int32* depth) const {
141   auto* slot = stats_.at(node_id).get();
142   *depth = slot->depth();
143   return slot->BestSplit(best);
144 }
145 }  // namespace tensorforest
146 }  // namespace tensorflow
147