1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
17 
#include <memory>
#include <vector>

#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
30 
31 namespace tensorflow {
32 namespace tensorforest {
33 
34 // Stores a FertileStats proto and implements operations on it.
35 class FertileStatsResource : public ResourceBase {
36  public:
37   // Constructor.
FertileStatsResource(const TensorForestParams & params)38   explicit FertileStatsResource(const TensorForestParams& params)
39       : params_(params) {
40     model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
41   }
42 
DebugString()43   string DebugString() const override { return "FertileStats"; }
44 
45   void ExtractFromProto(const FertileStats& stats);
46 
47   void PackToProto(FertileStats* stats) const;
48 
49   // Resets the resource and frees the proto.
50   // Caller needs to hold the mutex lock while calling this.
Reset()51   void Reset() {}
52 
53   // Reset the stats for a node, but leave the leaf_stats intact.
ResetSplitStats(int32 node_id,int32 depth)54   void ResetSplitStats(int32 node_id, int32 depth) {
55     collection_op_->ClearSlot(node_id);
56     collection_op_->InitializeSlot(node_id, depth);
57   }
58 
get_mutex()59   mutex* get_mutex() { return &mu_; }
60 
61   void MaybeInitialize();
62 
63   // Applies the example to the given leaf's statistics. Also applies it to the
64   // node's fertile slot's statistics if or initializes a split candidate,
65   // where applicable.  Returns if the node is finished or if it's ready to
66   // allocate to a fertile slot.
67   void AddExampleToStatsAndInitialize(
68       const std::unique_ptr<TensorDataSet>& input_data,
69       const InputTarget* target, const std::vector<int>& examples,
70       int32 node_id, bool* is_finished);
71 
72   // Allocate a fertile slot for each ready node, then new children up to
73   // max_fertile_nodes_.
74   void Allocate(int32 parent_depth, const std::vector<int32>& new_children);
75 
76   // Remove a node's fertile slot.  Should only be called when the node is
77   // no longer a leaf.
78   void Clear(int32 node);
79 
80   // Return the best SplitCandidate for a node, or NULL if no suitable split
81   // was found.
82   bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
83 
84  private:
85   mutex mu_;
86   std::shared_ptr<LeafModelOperator> model_op_;
87   std::unique_ptr<SplitCollectionOperator> collection_op_;
88   const TensorForestParams params_;
89 
90   void AllocateNode(int32 node_id, int32 depth);
91 };
92 
93 }  // namespace tensorforest
94 }  // namespace tensorflow
95 
96 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
97