1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
17 
#include <memory>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/mutex.h"
25 
26 namespace tensorflow {
27 namespace tensorforest {
28 
29 // Keep a tree ensemble in memory for efficient evaluation and mutation.
30 class DecisionTreeResource : public ResourceBase {
31  public:
32   // Constructor.
33   explicit DecisionTreeResource(const TensorForestParams& params);
34 
DebugString()35   string DebugString() const override {
36     return strings::StrCat("DecisionTree[size=",
37                            decision_tree_->decision_tree().nodes_size(), "]");
38   }
39 
40   void MaybeInitialize();
41 
decision_tree()42   const decision_trees::Model& decision_tree() const { return *decision_tree_; }
43 
mutable_decision_tree()44   decision_trees::Model* mutable_decision_tree() {
45     return decision_tree_.get();
46   }
47 
get_leaf(int32 id)48   const decision_trees::Leaf& get_leaf(int32 id) const {
49     return decision_tree_->decision_tree().nodes(id).leaf();
50   }
51 
get_mutable_tree_node(int32 id)52   decision_trees::TreeNode* get_mutable_tree_node(int32 id) {
53     return decision_tree_->mutable_decision_tree()->mutable_nodes(id);
54   }
55 
56   // Resets the resource and frees the proto.
57   // Caller needs to hold the mutex lock while calling this.
Reset()58   void Reset() { decision_tree_.reset(new decision_trees::Model()); }
59 
get_mutex()60   mutex* get_mutex() { return &mu_; }
61 
62   // Return the TreeNode for the leaf that the example ends up at according
63   // to decision_tree_. Also fill in that leaf's depth if it isn't nullptr.
64   int32 TraverseTree(const std::unique_ptr<TensorDataSet>& input_data,
65                      int example, int32* depth, TreePath* path) const;
66 
67   // Split the given node_id, turning it from a Leaf to a BinaryNode and
68   // setting it's split to the given best.  Add new children ids to
69   // new_children.
70   void SplitNode(int32 node_id, SplitCandidate* best,
71                  std::vector<int32>* new_children);
72 
73  private:
74   mutex mu_;
75   const TensorForestParams params_;
76   std::unique_ptr<decision_trees::Model> decision_tree_;
77   std::shared_ptr<LeafModelOperator> model_op_;
78   std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
79 };
80 
81 }  // namespace tensorforest
82 }  // namespace tensorflow
83 
84 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
85