1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"

#include <cstddef>
16
17 namespace tensorflow {
18 namespace tensorforest {
19
20 using decision_trees::DecisionTree;
21 using decision_trees::Leaf;
22 using decision_trees::TreeNode;
23
DecisionTreeResource(const TensorForestParams & params)24 DecisionTreeResource::DecisionTreeResource(const TensorForestParams& params)
25 : params_(params), decision_tree_(new decision_trees::Model()) {
26 model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
27 }
28
TraverseTree(const std::unique_ptr<TensorDataSet> & input_data,int example,int32 * leaf_depth,TreePath * path) const29 int32 DecisionTreeResource::TraverseTree(
30 const std::unique_ptr<TensorDataSet>& input_data, int example,
31 int32* leaf_depth, TreePath* path) const {
32 const DecisionTree& tree = decision_tree_->decision_tree();
33 int32 current_id = 0;
34 int32 depth = 0;
35 while (true) {
36 const TreeNode& current = tree.nodes(current_id);
37 if (path != nullptr) {
38 *path->add_nodes_visited() = current;
39 }
40 if (current.has_leaf()) {
41 if (leaf_depth != nullptr) {
42 *leaf_depth = depth;
43 }
44 return current_id;
45 }
46 ++depth;
47 const int32 next_id =
48 node_evaluators_[current_id]->Decide(input_data, example);
49 current_id = tree.nodes(next_id).node_id().value();
50 }
51 }
52
SplitNode(int32 node_id,SplitCandidate * best,std::vector<int32> * new_children)53 void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
54 std::vector<int32>* new_children) {
55 DecisionTree* tree = decision_tree_->mutable_decision_tree();
56 TreeNode* node = tree->mutable_nodes(node_id);
57 int32 newid = tree->nodes_size();
58
59 // left
60 new_children->push_back(newid);
61 TreeNode* new_left = tree->add_nodes();
62 new_left->mutable_node_id()->set_value(newid++);
63 Leaf* left_leaf = new_left->mutable_leaf();
64 model_op_->ExportModel(best->left_stats(), left_leaf);
65
66 // right
67 new_children->push_back(newid);
68 TreeNode* new_right = tree->add_nodes();
69 new_right->mutable_node_id()->set_value(newid);
70 Leaf* right_leaf = new_right->mutable_leaf();
71 model_op_->ExportModel(best->right_stats(), right_leaf);
72
73 node->clear_leaf();
74 node->mutable_binary_node()->Swap(best->mutable_split());
75 node->mutable_binary_node()->mutable_left_child_id()->set_value(newid - 1);
76 node->mutable_binary_node()->mutable_right_child_id()->set_value(newid);
77 while (node_evaluators_.size() <= node_id) {
78 node_evaluators_.emplace_back(nullptr);
79 }
80 node_evaluators_[node_id] = CreateDecisionNodeEvaluator(*node);
81 }
82
MaybeInitialize()83 void DecisionTreeResource::MaybeInitialize() {
84 DecisionTree* tree = decision_tree_->mutable_decision_tree();
85 if (tree->nodes_size() == 0) {
86 model_op_->InitModel(tree->add_nodes()->mutable_leaf());
87 } else if (node_evaluators_.empty()) { // reconstruct evaluators
88 for (const auto& node : tree->nodes()) {
89 if (node.has_leaf()) {
90 node_evaluators_.emplace_back(nullptr);
91 } else {
92 node_evaluators_.push_back(CreateDecisionNodeEvaluator(node));
93 }
94 }
95 }
96 }
97
98 } // namespace tensorforest
99 } // namespace tensorflow
100