1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
17 
#include <memory>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
27 
28 namespace tensorflow {
29 namespace tensorforest {
30 
31 // Class that can initialize and update split collections, and
32 // report if one is finished and ready to split.  Designed to be inherited
33 // from to implement techniques such as pruning and early/delayed finishing.
34 class SplitCollectionOperator {
35  public:
SplitCollectionOperator(const TensorForestParams & params)36   explicit SplitCollectionOperator(const TensorForestParams& params)
37       : params_(params) {}
~SplitCollectionOperator()38   virtual ~SplitCollectionOperator() {}
39 
40   // Return a new GrowStats object according to stats_type_;
41   virtual std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id,
42                                                      int32 depth) const;
43 
44   // Initialize from a previously serialized proto.
45   virtual void ExtractFromProto(const FertileStats& stats);
46 
47   // Serialize contents to the given proto.
48   virtual void PackToProto(FertileStats* stats) const;
49 
50   // Updates the slot's candidates with the new example.
51   // Assumes slot has been initialized.
52   virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
53                           const InputTarget* target,
54                           const std::vector<int>& examples,
55                           int32 node_id) const;
56 
57   // Create a new candidate and initialize it with the given example.
58   virtual void CreateAndInitializeCandidateWithExample(
59       const std::unique_ptr<TensorDataSet>& input_data,
60       const InputTarget* target, int example, int32 node_id) const;
61 
62   // Create a new GrowStats for the given node id and initialize it.
63   virtual void InitializeSlot(int32 node_id, int32 depth);
64 
65   // Called when the resource is deserialized, possibly needing an
66   // initialization.
MaybeInitialize()67   virtual void MaybeInitialize() {
68     if (stats_.empty()) {
69       InitializeSlot(0, 0);
70     }
71   }
72 
73   // Perform any necessary cleanup for any tracked state for the slot.
ClearSlot(int32 node_id)74   virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); }
75 
76   // Return true if slot is fully initialized.
77   virtual bool IsInitialized(int32 node_id) const;
78 
79   // Return true if slot is finished.
IsFinished(int32 node_id)80   virtual bool IsFinished(int32 node_id) const {
81     return stats_.at(node_id)->IsFinished();
82   }
83 
84   // Fill in best with the best split that node_id has, return true if this
85   // was successful, false if no good split was found.
86   virtual bool BestSplit(int32 node_id, SplitCandidate* best,
87                          int32* depth) const;
88 
89  protected:
90   const TensorForestParams& params_;
91   std::unordered_map<int32, std::unique_ptr<GrowStats>> stats_;
92 };
93 
94 class CollectionCreator {
95  public:
96   virtual std::unique_ptr<SplitCollectionOperator> Create(
97       const TensorForestParams& params) = 0;
~CollectionCreator()98   virtual ~CollectionCreator() {}
99 };
100 
101 class SplitCollectionOperatorFactory {
102  public:
103   static std::unique_ptr<SplitCollectionOperator> CreateSplitCollectionOperator(
104       const TensorForestParams& params);
105 
106   static std::unordered_map<int, CollectionCreator*> factories_;
107 };
108 
109 template <typename T>
110 class AnyCollectionCreator : public CollectionCreator {
111  public:
AnyCollectionCreator(SplitCollectionType type)112   AnyCollectionCreator(SplitCollectionType type) {
113     SplitCollectionOperatorFactory::factories_[type] = this;
114   }
Create(const TensorForestParams & params)115   virtual std::unique_ptr<SplitCollectionOperator> Create(
116       const TensorForestParams& params) {
117     return std::unique_ptr<SplitCollectionOperator>(new T(params));
118   }
119 };
120 
// Two-step token pasting so that __LINE__ is expanded to its numeric value
// before it is concatenated with the prefix.
#define TENSORFOREST_SPLIT_COLLECTION_CONCAT_IMPL(a, b) a##b
#define TENSORFOREST_SPLIT_COLLECTION_CONCAT(a, b) \
  TENSORFOREST_SPLIT_COLLECTION_CONCAT_IMPL(a, b)

// Registers `cls` as the SplitCollectionOperator implementation for the
// SplitCollectionType `name` by defining a file-local
// AnyCollectionCreator<cls> whose constructor performs the registration.
//
// The generated variable name includes the current line number, so the macro
// may be used several times in one translation unit (one use per source
// line).  The creator template is fully qualified, so the macro also works
// outside namespace tensorflow::tensorforest.  Intended for use at namespace
// scope in a .cc file.
#define REGISTER_SPLIT_COLLECTION(name, cls)                            \
  namespace {                                                           \
  ::tensorflow::tensorforest::AnyCollectionCreator<cls>                 \
      TENSORFOREST_SPLIT_COLLECTION_CONCAT(split_collection_creator_,   \
                                           __LINE__)(name);             \
  }
125 
126 }  // namespace tensorforest
127 }  // namespace tensorflow
128 
129 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
130