1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ 17 18 #include <vector> 19 20 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" 21 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" 22 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" 23 #include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" 24 #include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h" 25 #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" 26 #include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" 27 #include "tensorflow/core/framework/resource_mgr.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/platform/mutex.h" 30 31 namespace tensorflow { 32 namespace tensorforest { 33 34 // Stores a FertileStats proto and implements operations on it. 35 class FertileStatsResource : public ResourceBase { 36 public: 37 // Constructor. FertileStatsResource(const TensorForestParams & params)38 explicit FertileStatsResource(const TensorForestParams& params) 39 : params_(params) { 40 model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); 41 } 42 DebugString()43 string DebugString() const override { return "FertileStats"; } 44 45 void ExtractFromProto(const FertileStats& stats); 46 47 void PackToProto(FertileStats* stats) const; 48 49 // Resets the resource and frees the proto. 50 // Caller needs to hold the mutex lock while calling this. Reset()51 void Reset() {} 52 53 // Reset the stats for a node, but leave the leaf_stats intact. ResetSplitStats(int32 node_id,int32 depth)54 void ResetSplitStats(int32 node_id, int32 depth) { 55 collection_op_->ClearSlot(node_id); 56 collection_op_->InitializeSlot(node_id, depth); 57 } 58 get_mutex()59 mutex* get_mutex() { return &mu_; } 60 61 void MaybeInitialize(); 62 63 // Applies the example to the given leaf's statistics. Also applies it to the 64 // node's fertile slot's statistics if or initializes a split candidate, 65 // where applicable. Returns if the node is finished or if it's ready to 66 // allocate to a fertile slot. 67 void AddExampleToStatsAndInitialize( 68 const std::unique_ptr<TensorDataSet>& input_data, 69 const InputTarget* target, const std::vector<int>& examples, 70 int32 node_id, bool* is_finished); 71 72 // Allocate a fertile slot for each ready node, then new children up to 73 // max_fertile_nodes_. 74 void Allocate(int32 parent_depth, const std::vector<int32>& new_children); 75 76 // Remove a node's fertile slot. Should only be called when the node is 77 // no longer a leaf. 78 void Clear(int32 node); 79 80 // Return the best SplitCandidate for a node, or NULL if no suitable split 81 // was found. 82 bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth); 83 84 private: 85 mutex mu_; 86 std::shared_ptr<LeafModelOperator> model_op_; 87 std::unique_ptr<SplitCollectionOperator> collection_op_; 88 const TensorForestParams params_; 89 90 void AllocateNode(int32 node_id, int32 depth); 91 }; 92 93 } // namespace tensorforest 94 } // namespace tensorflow 95 96 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ 97