1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
16 
17 namespace tensorflow {
18 namespace tensorforest {
19 
20 using decision_trees::DecisionTree;
21 using decision_trees::Leaf;
22 using decision_trees::TreeNode;
23 
DecisionTreeResource(const TensorForestParams & params)24 DecisionTreeResource::DecisionTreeResource(const TensorForestParams& params)
25     : params_(params), decision_tree_(new decision_trees::Model()) {
26   model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
27 }
28 
TraverseTree(const std::unique_ptr<TensorDataSet> & input_data,int example,int32 * leaf_depth,TreePath * path) const29 int32 DecisionTreeResource::TraverseTree(
30     const std::unique_ptr<TensorDataSet>& input_data, int example,
31     int32* leaf_depth, TreePath* path) const {
32   const DecisionTree& tree = decision_tree_->decision_tree();
33   int32 current_id = 0;
34   int32 depth = 0;
35   while (true) {
36     const TreeNode& current = tree.nodes(current_id);
37     if (path != nullptr) {
38       *path->add_nodes_visited() = current;
39     }
40     if (current.has_leaf()) {
41       if (leaf_depth != nullptr) {
42         *leaf_depth = depth;
43       }
44       return current_id;
45     }
46     ++depth;
47     const int32 next_id =
48         node_evaluators_[current_id]->Decide(input_data, example);
49     current_id = tree.nodes(next_id).node_id().value();
50   }
51 }
52 
SplitNode(int32 node_id,SplitCandidate * best,std::vector<int32> * new_children)53 void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
54                                      std::vector<int32>* new_children) {
55   DecisionTree* tree = decision_tree_->mutable_decision_tree();
56   TreeNode* node = tree->mutable_nodes(node_id);
57   int32 newid = tree->nodes_size();
58 
59   // left
60   new_children->push_back(newid);
61   TreeNode* new_left = tree->add_nodes();
62   new_left->mutable_node_id()->set_value(newid++);
63   Leaf* left_leaf = new_left->mutable_leaf();
64   model_op_->ExportModel(best->left_stats(), left_leaf);
65 
66   // right
67   new_children->push_back(newid);
68   TreeNode* new_right = tree->add_nodes();
69   new_right->mutable_node_id()->set_value(newid);
70   Leaf* right_leaf = new_right->mutable_leaf();
71   model_op_->ExportModel(best->right_stats(), right_leaf);
72 
73   node->clear_leaf();
74   node->mutable_binary_node()->Swap(best->mutable_split());
75   node->mutable_binary_node()->mutable_left_child_id()->set_value(newid - 1);
76   node->mutable_binary_node()->mutable_right_child_id()->set_value(newid);
77   while (node_evaluators_.size() <= node_id) {
78     node_evaluators_.emplace_back(nullptr);
79   }
80   node_evaluators_[node_id] = CreateDecisionNodeEvaluator(*node);
81 }
82 
MaybeInitialize()83 void DecisionTreeResource::MaybeInitialize() {
84   DecisionTree* tree = decision_tree_->mutable_decision_tree();
85   if (tree->nodes_size() == 0) {
86     model_op_->InitModel(tree->add_nodes()->mutable_leaf());
87   } else if (node_evaluators_.empty()) {  // reconstruct evaluators
88     for (const auto& node : tree->nodes()) {
89       if (node.has_leaf()) {
90         node_evaluators_.emplace_back(nullptr);
91       } else {
92         node_evaluators_.push_back(CreateDecisionNodeEvaluator(node));
93       }
94     }
95   }
96 }
97 
98 }  // namespace tensorforest
99 }  // namespace tensorflow
100