1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
16 #include "tensorflow/core/lib/strings/numbers.h"
17
18 namespace tensorflow {
19 namespace tensorforest {
20
CreateDecisionNodeEvaluator(const decision_trees::TreeNode & node)21 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator(
22 const decision_trees::TreeNode& node) {
23 const decision_trees::BinaryNode& bnode = node.binary_node();
24 return CreateBinaryDecisionNodeEvaluator(bnode, bnode.left_child_id().value(),
25 bnode.right_child_id().value());
26 }
27
CreateBinaryDecisionNodeEvaluator(const decision_trees::BinaryNode & bnode,int32 left,int32 right)28 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator(
29 const decision_trees::BinaryNode& bnode, int32 left, int32 right) {
30 if (bnode.has_inequality_left_child_test()) {
31 const auto& test = bnode.inequality_left_child_test();
32 if (test.has_oblique()) {
33 return std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator>(
34 new ObliqueInequalityDecisionNodeEvaluator(test, left, right));
35 } else {
36 return std::unique_ptr<InequalityDecisionNodeEvaluator>(
37 new InequalityDecisionNodeEvaluator(test, left, right));
38 }
39 } else {
40 decision_trees::MatchingValuesTest test;
41 if (bnode.custom_left_child_test().UnpackTo(&test)) {
42 return std::unique_ptr<MatchingValuesDecisionNodeEvaluator>(
43 new MatchingValuesDecisionNodeEvaluator(test, left, right));
44 } else {
45 LOG(ERROR) << "Unknown split test: " << bnode.DebugString();
46 return nullptr;
47 }
48 }
49 }
50
InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest & test,int32 left,int32 right)51 InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
52 const decision_trees::InequalityTest& test, int32 left, int32 right)
53 : BinaryDecisionNodeEvaluator(left, right) {
54 CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
55 << "Invalid feature ID: [" << test.feature_id().id().value() << "]";
56 threshold_ = test.threshold().float_value();
57 _test_type = test.type();
58 }
59
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const60 int32 InequalityDecisionNodeEvaluator::Decide(
61 const std::unique_ptr<TensorDataSet>& dataset, int example) const {
62 const float val = dataset->GetExampleValue(example, feature_num_);
63 switch (_test_type) {
64 case decision_trees::InequalityTest::LESS_OR_EQUAL:
65 return val <= threshold_ ? left_child_id_ : right_child_id_;
66 case decision_trees::InequalityTest::LESS_THAN:
67 return val < threshold_ ? left_child_id_ : right_child_id_;
68 case decision_trees::InequalityTest::GREATER_OR_EQUAL:
69 return val >= threshold_ ? left_child_id_ : right_child_id_;
70 case decision_trees::InequalityTest::GREATER_THAN:
71 return val > threshold_ ? left_child_id_ : right_child_id_;
72 default:
73 LOG(ERROR) << "Unknown split test type: " << _test_type;
74 return -1;
75 }
76 }
77
ObliqueInequalityDecisionNodeEvaluator(const decision_trees::InequalityTest & test,int32 left,int32 right)78 ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator(
79 const decision_trees::InequalityTest& test, int32 left, int32 right)
80 : BinaryDecisionNodeEvaluator(left, right) {
81 for (int i = 0; i < test.oblique().features_size(); ++i) {
82 int32 val;
83 CHECK(safe_strto32(test.oblique().features(i).id().value(), &val))
84 << "Invalid feature ID: [" << test.oblique().features(i).id().value()
85 << "]";
86 feature_num_.push_back(val);
87 feature_weights_.push_back(test.oblique().weights(i));
88 }
89 threshold_ = test.threshold().float_value();
90 }
91
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const92 int32 ObliqueInequalityDecisionNodeEvaluator::Decide(
93 const std::unique_ptr<TensorDataSet>& dataset, int example) const {
94 float val = 0;
95 for (int i = 0; i < feature_num_.size(); ++i) {
96 val += feature_weights_[i] *
97 dataset->GetExampleValue(example, feature_num_[i]);
98 }
99
100 if (val <= threshold_) {
101 return left_child_id_;
102 } else {
103 return right_child_id_;
104 }
105 }
106
MatchingValuesDecisionNodeEvaluator(const decision_trees::MatchingValuesTest & test,int32 left,int32 right)107 MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator(
108 const decision_trees::MatchingValuesTest& test, int32 left, int32 right)
109 : BinaryDecisionNodeEvaluator(left, right) {
110 CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
111 << "Invalid feature ID: [" << test.feature_id().id().value() << "]";
112 for (const auto& val : test.value()) {
113 values_.push_back(val.float_value());
114 }
115 inverse_ = test.inverse();
116 }
117
Decide(const std::unique_ptr<TensorDataSet> & dataset,int example) const118 int32 MatchingValuesDecisionNodeEvaluator::Decide(
119 const std::unique_ptr<TensorDataSet>& dataset, int example) const {
120 const float val = dataset->GetExampleValue(example, feature_num_);
121 for (float testval : values_) {
122 if (val == testval) {
123 return inverse_ ? right_child_id_ : left_child_id_;
124 }
125 }
126
127 return inverse_ ? left_child_id_ : right_child_id_;
128 }
129
130 } // namespace tensorforest
131 } // namespace tensorflow
132