1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 17 18 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 19 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" 20 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" 21 22 namespace tensorflow { 23 namespace tensorforest { 24 25 // Base class for evaluators of decision nodes that effectively copy proto 26 // contents into C++ structures for faster execution. 27 class DecisionNodeEvaluator { 28 public: ~DecisionNodeEvaluator()29 virtual ~DecisionNodeEvaluator() {} 30 31 // Returns the index of the child node. 32 virtual int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 33 int example) const = 0; 34 }; 35 36 // An evaluator for binary decisions with left and right children. 37 class BinaryDecisionNodeEvaluator : public DecisionNodeEvaluator { 38 protected: BinaryDecisionNodeEvaluator(int32 left,int32 right)39 BinaryDecisionNodeEvaluator(int32 left, int32 right) 40 : left_child_id_(left), right_child_id_(right) {} 41 42 int32 left_child_id_; 43 int32 right_child_id_; 44 }; 45 46 // Evaluator for basic inequality decisions (f[x] <= T). 47 class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { 48 public: 49 InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest& test, 50 int32 left, int32 right); 51 52 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 53 int example) const override; 54 55 protected: 56 int32 feature_num_; 57 float threshold_; 58 ::tensorflow::decision_trees::InequalityTest_Type _test_type; 59 }; 60 61 // Evaluator for splits with multiple weighted features. 62 class ObliqueInequalityDecisionNodeEvaluator 63 : public BinaryDecisionNodeEvaluator { 64 public: 65 ObliqueInequalityDecisionNodeEvaluator( 66 const decision_trees::InequalityTest& test, int32 left, int32 right); 67 68 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 69 int example) const override; 70 71 protected: 72 std::vector<int32> feature_num_; 73 std::vector<float> feature_weights_; 74 float threshold_; 75 }; 76 77 // Evaluator for contains-in-set decisions. Also supports inverse (not-in-set). 78 class MatchingValuesDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { 79 public: 80 MatchingValuesDecisionNodeEvaluator( 81 const decision_trees::MatchingValuesTest& test, int32 left, int32 right); 82 83 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 84 int example) const override; 85 86 protected: 87 int32 feature_num_; 88 std::vector<float> values_; 89 bool inverse_; 90 }; 91 92 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator( 93 const decision_trees::TreeNode& node); 94 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator( 95 const decision_trees::BinaryNode& node, int32 left, int32 right); 96 97 struct CandidateEvalatorCollection { 98 std::vector<std::unique_ptr<DecisionNodeEvaluator>> splits; 99 }; 100 101 } // namespace tensorforest 102 } // namespace tensorflow 103 104 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 105