1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
16 #include "tensorflow/core/lib/strings/numbers.h"
17 
18 namespace tensorflow {
19 namespace tensorforest {
20 
CreateDecisionNodeEvaluator(const decision_trees::TreeNode & node)21 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator(
22     const decision_trees::TreeNode& node) {
23   const decision_trees::BinaryNode& bnode = node.binary_node();
24   return CreateBinaryDecisionNodeEvaluator(bnode, bnode.left_child_id().value(),
25                                            bnode.right_child_id().value());
26 }
27 
CreateBinaryDecisionNodeEvaluator(const decision_trees::BinaryNode & bnode,int32 left,int32 right)28 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator(
29     const decision_trees::BinaryNode& bnode, int32 left, int32 right) {
30   if (bnode.has_inequality_left_child_test()) {
31     const auto& test = bnode.inequality_left_child_test();
32     if (test.has_oblique()) {
33       return std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator>(
34           new ObliqueInequalityDecisionNodeEvaluator(test, left, right));
35     } else {
36       return std::unique_ptr<InequalityDecisionNodeEvaluator>(
37           new InequalityDecisionNodeEvaluator(test, left, right));
38     }
39   } else {
40     decision_trees::MatchingValuesTest test;
41     if (bnode.custom_left_child_test().UnpackTo(&test)) {
42       return std::unique_ptr<MatchingValuesDecisionNodeEvaluator>(
43           new MatchingValuesDecisionNodeEvaluator(test, left, right));
44     } else {
45       LOG(ERROR) << "Unknown split test: " << bnode.DebugString();
46       return nullptr;
47     }
48   }
49 }
50 
InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest & test,int32 left,int32 right)51 InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
52     const decision_trees::InequalityTest& test, int32 left, int32 right)
53     : BinaryDecisionNodeEvaluator(left, right) {
54   CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
55       << "Invalid feature ID: [" << test.feature_id().id().value() << "]";
56   threshold_ = test.threshold().float_value();
57   _test_type = test.type();
58 }
59 
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const60 int32 InequalityDecisionNodeEvaluator::Decide(
61     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
62   const float val = dataset->GetExampleValue(example, feature_num_);
63   switch (_test_type) {
64     case decision_trees::InequalityTest::LESS_OR_EQUAL:
65       return val <= threshold_ ? left_child_id_ : right_child_id_;
66     case decision_trees::InequalityTest::LESS_THAN:
67       return val < threshold_ ? left_child_id_ : right_child_id_;
68     case decision_trees::InequalityTest::GREATER_OR_EQUAL:
69       return val >= threshold_ ? left_child_id_ : right_child_id_;
70     case decision_trees::InequalityTest::GREATER_THAN:
71       return val > threshold_ ? left_child_id_ : right_child_id_;
72     default:
73       LOG(ERROR) << "Unknown split test type: " << _test_type;
74       return -1;
75   }
76 }
77 
ObliqueInequalityDecisionNodeEvaluator(const decision_trees::InequalityTest & test,int32 left,int32 right)78 ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator(
79     const decision_trees::InequalityTest& test, int32 left, int32 right)
80     : BinaryDecisionNodeEvaluator(left, right) {
81   for (int i = 0; i < test.oblique().features_size(); ++i) {
82     int32 val;
83     CHECK(safe_strto32(test.oblique().features(i).id().value(), &val))
84         << "Invalid feature ID: [" << test.oblique().features(i).id().value()
85         << "]";
86     feature_num_.push_back(val);
87     feature_weights_.push_back(test.oblique().weights(i));
88   }
89   threshold_ = test.threshold().float_value();
90 }
91 
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const92 int32 ObliqueInequalityDecisionNodeEvaluator::Decide(
93     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
94   float val = 0;
95   for (int i = 0; i < feature_num_.size(); ++i) {
96     val += feature_weights_[i] *
97            dataset->GetExampleValue(example, feature_num_[i]);
98   }
99 
100   if (val <= threshold_) {
101     return left_child_id_;
102   } else {
103     return right_child_id_;
104   }
105 }
106 
MatchingValuesDecisionNodeEvaluator(const decision_trees::MatchingValuesTest & test,int32 left,int32 right)107 MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator(
108     const decision_trees::MatchingValuesTest& test, int32 left, int32 right)
109     : BinaryDecisionNodeEvaluator(left, right) {
110   CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
111       << "Invalid feature ID: [" << test.feature_id().id().value() << "]";
112   for (const auto& val : test.value()) {
113     values_.push_back(val.float_value());
114   }
115   inverse_ = test.inverse();
116 }
117 
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const118 int32 MatchingValuesDecisionNodeEvaluator::Decide(
119     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
120   const float val = dataset->GetExampleValue(example, feature_num_);
121   for (float testval : values_) {
122     if (val == testval) {
123       return inverse_ ? right_child_id_ : left_child_id_;
124     }
125   }
126 
127   return inverse_ ? left_child_id_ : right_child_id_;
128 }
129 
130 }  // namespace tensorforest
131 }  // namespace tensorflow
132