1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
16 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
17 #include "tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h"
18 #include "tensorflow/core/platform/test.h"
19
20 namespace tensorflow {
21 namespace {
22
23 using tensorflow::decision_trees::InequalityTest;
24 using tensorflow::decision_trees::MatchingValuesTest;
25 using tensorflow::tensorforest::InequalityDecisionNodeEvaluator;
26 using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator;
27 using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator;
28
// A LESS_OR_EQUAL split on feature "0" with threshold 3.0 is expected to send
// values at or below the threshold to the left child and larger values to the
// right child.
TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) {
  const int kLeftChild = 0;
  const int kRightChild = 1;

  InequalityTest split;
  split.set_type(InequalityTest::LESS_OR_EQUAL);
  split.mutable_threshold()->set_float_value(3.0);
  split.mutable_feature_id()->mutable_id()->set_value("0");

  std::unique_ptr<InequalityDecisionNodeEvaluator> evaluator;
  evaluator.reset(
      new InequalityDecisionNodeEvaluator(split, kLeftChild, kRightChild));

  // Single-feature data set: example i holds value i.
  std::unique_ptr<tensorflow::tensorforest::TensorDataSet> data;
  data.reset(new tensorflow::tensorforest::TestableDataSet(
      {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(data, 2), kLeftChild);   // below threshold
  ASSERT_EQ(evaluator->Decide(data, 3), kLeftChild);   // equal to threshold
  ASSERT_EQ(evaluator->Decide(data, 4), kRightChild);  // above threshold
}
45
// LESS_THAN makes the threshold exclusive: only values strictly below 3.0 are
// expected on the left child (0); the threshold value itself goes right (1).
TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) {
  InequalityTest less_than;
  less_than.mutable_feature_id()->mutable_id()->set_value("0");
  less_than.mutable_threshold()->set_float_value(3.0);
  less_than.set_type(InequalityTest::LESS_THAN);
  const std::unique_ptr<InequalityDecisionNodeEvaluator> node(
      new InequalityDecisionNodeEvaluator(less_than, 0, 1));

  const std::unique_ptr<tensorflow::tensorforest::TensorDataSet> examples(
      new tensorflow::tensorforest::TestableDataSet(
          {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  // Each row is {example index, expected child id}; rows are checked in order
  // and the test stops at the first mismatch.
  const int kExpectations[][2] = {{2, 0}, {3, 1}, {4, 1}};
  for (const auto& expectation : kExpectations) {
    ASSERT_EQ(node->Decide(examples, expectation[0]), expectation[1]);
  }
}
62
// GREATER_OR_EQUAL flips the routing: values at or above the 3.0 threshold are
// expected on the left child (id 0) and smaller values on the right (id 1).
TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  const int left_id = 0;
  const int right_id = 1;

  InequalityTest at_least_three;
  at_least_three.set_type(InequalityTest::GREATER_OR_EQUAL);
  at_least_three.mutable_feature_id()->mutable_id()->set_value("0");
  at_least_three.mutable_threshold()->set_float_value(3.0);

  std::unique_ptr<InequalityDecisionNodeEvaluator> ge_eval(
      new InequalityDecisionNodeEvaluator(at_least_three, left_id, right_id));
  std::unique_ptr<TensorDataSet> ds(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  // Shorthand for routing one example of `ds` through the evaluator.
  auto route = [&](int example) { return ge_eval->Decide(ds, example); };

  ASSERT_EQ(route(2), right_id);
  ASSERT_EQ(route(3), left_id);
  ASSERT_EQ(route(4), left_id);
}
79
// GREATER_THAN treats the threshold as exclusive: only values strictly above
// 3.0 are expected on the left child (0); the threshold value goes right (1).
TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) {
  InequalityTest greater_than;
  greater_than.mutable_threshold()->set_float_value(3.0);
  greater_than.set_type(InequalityTest::GREATER_THAN);
  greater_than.mutable_feature_id()->mutable_id()->set_value("0");

  std::unique_ptr<InequalityDecisionNodeEvaluator> node_eval;
  node_eval.reset(new InequalityDecisionNodeEvaluator(greater_than, 0, 1));

  std::unique_ptr<tensorflow::tensorforest::TensorDataSet> samples(
      new tensorflow::tensorforest::TestableDataSet(
          {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  // Example indices whose values lie below, at and above the threshold
  // (with one feature, example i holds value i).
  const int below = 2;
  const int at = 3;
  const int above = 4;
  ASSERT_EQ(node_eval->Decide(samples, below), 1);
  ASSERT_EQ(node_eval->Decide(samples, at), 1);
  ASSERT_EQ(node_eval->Decide(samples, above), 0);
}
96
// A non-inverted MatchingValuesTest is expected to send an example to the left
// child (0) when its feature value equals one of the listed values (3.0 or
// 5.0), and to the right child (1) otherwise.
TEST(MatchingDecisionNodeEvaluatorTest, Basic) {
  MatchingValuesTest match;
  match.mutable_feature_id()->mutable_id()->set_value("0");
  for (const double listed : {3.0, 5.0}) {
    match.add_value()->set_float_value(listed);
  }

  std::unique_ptr<MatchingValuesDecisionNodeEvaluator> matcher(
      new MatchingValuesDecisionNodeEvaluator(match, 0, 1));

  std::unique_ptr<tensorflow::tensorforest::TensorDataSet> rows(
      new tensorflow::tensorforest::TestableDataSet(
          {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  // Examples 3 and 5 hold listed values; examples 2 and 4 do not.
  ASSERT_EQ(matcher->Decide(rows, 2), 1);
  ASSERT_EQ(matcher->Decide(rows, 3), 0);
  ASSERT_EQ(matcher->Decide(rows, 4), 1);
  ASSERT_EQ(matcher->Decide(rows, 5), 0);
}
115
// With set_inverse(true), MatchingValuesTest routing is flipped. Listed values
// (3.0, 5.0) are expected on the right child (1) and all other values on the
// left child (0).
TEST(MatchingDecisionNodeEvaluatorTest, Inverse) {
  const int kUnmatchedChild = 0;
  const int kMatchedChild = 1;

  MatchingValuesTest inverted;
  inverted.set_inverse(true);
  inverted.mutable_feature_id()->mutable_id()->set_value("0");
  inverted.add_value()->set_float_value(3.0);
  inverted.add_value()->set_float_value(5.0);

  std::unique_ptr<MatchingValuesDecisionNodeEvaluator> inverted_eval(
      new MatchingValuesDecisionNodeEvaluator(inverted, 0, 1));

  std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
      new tensorflow::tensorforest::TestableDataSet(
          {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  // Walk examples 2..5 in order; with one feature, example i holds value i.
  for (int example = 2; example <= 5; ++example) {
    const bool listed = (example == 3 || example == 5);
    ASSERT_EQ(inverted_eval->Decide(dataset, example),
              listed ? kMatchedChild : kUnmatchedChild);
  }
}
135
TEST(ObliqueDecisionNodeEvaluatorTest,Basic)136 TEST(ObliqueDecisionNodeEvaluatorTest, Basic) {
137 InequalityTest test;
138 auto* feat1 = test.mutable_oblique()->add_features();
139 feat1->mutable_id()->set_value("0");
140 test.mutable_oblique()->add_weights(1.0);
141 auto* feat2 = test.mutable_oblique()->add_features();
142 feat2->mutable_id()->set_value("1");
143 test.mutable_oblique()->add_weights(1.0);
144
145 test.mutable_threshold()->set_float_value(3.0);
146 test.set_type(InequalityTest::LESS_OR_EQUAL);
147
148 std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator> eval(
149 new ObliqueInequalityDecisionNodeEvaluator(test, 0, 1));
150
151 std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
152 new tensorflow::tensorforest::TestableDataSet(
153 {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 2));
154
155 ASSERT_EQ(eval->Decide(dataset, 0), 0);
156 ASSERT_EQ(eval->Decide(dataset, 1), 1);
157 }
158
159 } // namespace
160 } // namespace tensorflow
161