// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_

#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"

namespace tensorflow {
namespace tensorforest {

// Class that can initialize and update split collections, and
// report if one is finished and ready to split. Designed to be inherited
// from to implement techniques such as pruning and early/delayed finishing.
34 class SplitCollectionOperator { 35 public: SplitCollectionOperator(const TensorForestParams & params)36 explicit SplitCollectionOperator(const TensorForestParams& params) 37 : params_(params) {} ~SplitCollectionOperator()38 virtual ~SplitCollectionOperator() {} 39 40 // Return a new GrowStats object according to stats_type_; 41 virtual std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id, 42 int32 depth) const; 43 44 // Initialize from a previously serialized proto. 45 virtual void ExtractFromProto(const FertileStats& stats); 46 47 // Serialize contents to the given proto. 48 virtual void PackToProto(FertileStats* stats) const; 49 50 // Updates the slot's candidates with the new example. 51 // Assumes slot has been initialized. 52 virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data, 53 const InputTarget* target, 54 const std::vector<int>& examples, 55 int32 node_id) const; 56 57 // Create a new candidate and initialize it with the given example. 58 virtual void CreateAndInitializeCandidateWithExample( 59 const std::unique_ptr<TensorDataSet>& input_data, 60 const InputTarget* target, int example, int32 node_id) const; 61 62 // Create a new GrowStats for the given node id and initialize it. 63 virtual void InitializeSlot(int32 node_id, int32 depth); 64 65 // Called when the resource is deserialized, possibly needing an 66 // initialization. MaybeInitialize()67 virtual void MaybeInitialize() { 68 if (stats_.empty()) { 69 InitializeSlot(0, 0); 70 } 71 } 72 73 // Perform any necessary cleanup for any tracked state for the slot. ClearSlot(int32 node_id)74 virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); } 75 76 // Return true if slot is fully initialized. 77 virtual bool IsInitialized(int32 node_id) const; 78 79 // Return true if slot is finished. 
IsFinished(int32 node_id)80 virtual bool IsFinished(int32 node_id) const { 81 return stats_.at(node_id)->IsFinished(); 82 } 83 84 // Fill in best with the best split that node_id has, return true if this 85 // was successful, false if no good split was found. 86 virtual bool BestSplit(int32 node_id, SplitCandidate* best, 87 int32* depth) const; 88 89 protected: 90 const TensorForestParams& params_; 91 std::unordered_map<int32, std::unique_ptr<GrowStats>> stats_; 92 }; 93 94 class CollectionCreator { 95 public: 96 virtual std::unique_ptr<SplitCollectionOperator> Create( 97 const TensorForestParams& params) = 0; ~CollectionCreator()98 virtual ~CollectionCreator() {} 99 }; 100 101 class SplitCollectionOperatorFactory { 102 public: 103 static std::unique_ptr<SplitCollectionOperator> CreateSplitCollectionOperator( 104 const TensorForestParams& params); 105 106 static std::unordered_map<int, CollectionCreator*> factories_; 107 }; 108 109 template <typename T> 110 class AnyCollectionCreator : public CollectionCreator { 111 public: AnyCollectionCreator(SplitCollectionType type)112 AnyCollectionCreator(SplitCollectionType type) { 113 SplitCollectionOperatorFactory::factories_[type] = this; 114 } Create(const TensorForestParams & params)115 virtual std::unique_ptr<SplitCollectionOperator> Create( 116 const TensorForestParams& params) { 117 return std::unique_ptr<SplitCollectionOperator>(new T(params)); 118 } 119 }; 120 121 #define REGISTER_SPLIT_COLLECTION(name, cls) \ 122 namespace { \ 123 AnyCollectionCreator<cls> creator(name); \ 124 } 125 126 } // namespace tensorforest 127 } // namespace tensorflow 128 129 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ 130