1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
17 
#include <memory>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
21 
22 namespace tensorflow {
23 namespace tensorforest {
24 
25 // Base class for evaluators of decision nodes that effectively copy proto
26 // contents into C++ structures for faster execution.
27 class DecisionNodeEvaluator {
28  public:
~DecisionNodeEvaluator()29   virtual ~DecisionNodeEvaluator() {}
30 
31   // Returns the index of the child node.
32   virtual int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
33                        int example) const = 0;
34 };
35 
36 // An evaluator for binary decisions with left and right children.
37 class BinaryDecisionNodeEvaluator : public DecisionNodeEvaluator {
38  protected:
BinaryDecisionNodeEvaluator(int32 left,int32 right)39   BinaryDecisionNodeEvaluator(int32 left, int32 right)
40       : left_child_id_(left), right_child_id_(right) {}
41 
42   int32 left_child_id_;
43   int32 right_child_id_;
44 };
45 
46 // Evaluator for basic inequality decisions (f[x] <= T).
47 class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
48  public:
49   InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest& test,
50                                   int32 left, int32 right);
51 
52   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
53                int example) const override;
54 
55  protected:
56   int32 feature_num_;
57   float threshold_;
58   ::tensorflow::decision_trees::InequalityTest_Type _test_type;
59 };
60 
61 // Evaluator for splits with multiple weighted features.
62 class ObliqueInequalityDecisionNodeEvaluator
63     : public BinaryDecisionNodeEvaluator {
64  public:
65   ObliqueInequalityDecisionNodeEvaluator(
66       const decision_trees::InequalityTest& test, int32 left, int32 right);
67 
68   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
69                int example) const override;
70 
71  protected:
72   std::vector<int32> feature_num_;
73   std::vector<float> feature_weights_;
74   float threshold_;
75 };
76 
77 // Evaluator for contains-in-set decisions.  Also supports inverse (not-in-set).
78 class MatchingValuesDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
79  public:
80   MatchingValuesDecisionNodeEvaluator(
81       const decision_trees::MatchingValuesTest& test, int32 left, int32 right);
82 
83   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
84                int example) const override;
85 
86  protected:
87   int32 feature_num_;
88   std::vector<float> values_;
89   bool inverse_;
90 };
91 
92 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator(
93     const decision_trees::TreeNode& node);
94 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator(
95     const decision_trees::BinaryNode& node, int32 left, int32 right);
96 
97 struct CandidateEvalatorCollection {
98   std::vector<std::unique_ptr<DecisionNodeEvaluator>> splits;
99 };
100 
101 }  // namespace tensorforest
102 }  // namespace tensorflow
103 
104 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
105