1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ 17 18 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 19 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" 20 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" 21 #include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" 22 #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" 23 #include "tensorflow/core/framework/resource_mgr.h" 24 #include "tensorflow/core/platform/mutex.h" 25 26 namespace tensorflow { 27 namespace tensorforest { 28 29 // Keep a tree ensemble in memory for efficient evaluation and mutation. 30 class DecisionTreeResource : public ResourceBase { 31 public: 32 // Constructor. 33 explicit DecisionTreeResource(const TensorForestParams& params); 34 DebugString()35 string DebugString() const override { 36 return strings::StrCat("DecisionTree[size=", 37 decision_tree_->decision_tree().nodes_size(), "]"); 38 } 39 40 void MaybeInitialize(); 41 decision_tree()42 const decision_trees::Model& decision_tree() const { return *decision_tree_; } 43 mutable_decision_tree()44 decision_trees::Model* mutable_decision_tree() { 45 return decision_tree_.get(); 46 } 47 get_leaf(int32 id)48 const decision_trees::Leaf& get_leaf(int32 id) const { 49 return decision_tree_->decision_tree().nodes(id).leaf(); 50 } 51 get_mutable_tree_node(int32 id)52 decision_trees::TreeNode* get_mutable_tree_node(int32 id) { 53 return decision_tree_->mutable_decision_tree()->mutable_nodes(id); 54 } 55 56 // Resets the resource and frees the proto. 57 // Caller needs to hold the mutex lock while calling this. Reset()58 void Reset() { decision_tree_.reset(new decision_trees::Model()); } 59 get_mutex()60 mutex* get_mutex() { return &mu_; } 61 62 // Return the TreeNode for the leaf that the example ends up at according 63 // to decision_tree_. Also fill in that leaf's depth if it isn't nullptr. 64 int32 TraverseTree(const std::unique_ptr<TensorDataSet>& input_data, 65 int example, int32* depth, TreePath* path) const; 66 67 // Split the given node_id, turning it from a Leaf to a BinaryNode and 68 // setting it's split to the given best. Add new children ids to 69 // new_children. 70 void SplitNode(int32 node_id, SplitCandidate* best, 71 std::vector<int32>* new_children); 72 73 private: 74 mutex mu_; 75 const TensorForestParams params_; 76 std::unique_ptr<decision_trees::Model> decision_tree_; 77 std::shared_ptr<LeafModelOperator> model_op_; 78 std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_; 79 }; 80 81 } // namespace tensorforest 82 } // namespace tensorflow 83 84 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ 85