• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
16 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
17 #include "tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h"
18 #include "tensorflow/core/platform/test.h"
19 
20 namespace tensorflow {
21 namespace {
22 
23 using tensorflow::decision_trees::InequalityTest;
24 using tensorflow::decision_trees::MatchingValuesTest;
25 using tensorflow::tensorforest::InequalityDecisionNodeEvaluator;
26 using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator;
27 using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator;
28 
// Checks an InequalityDecisionNodeEvaluator configured with LESS_OR_EQUAL and
// a threshold of 3.0 on feature "0": examples 2 and 3 are expected to yield 0,
// example 4 is expected to yield 1.
TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  InequalityTest split;
  split.set_type(InequalityTest::LESS_OR_EQUAL);
  split.mutable_threshold()->set_float_value(3.0);
  split.mutable_feature_id()->mutable_id()->set_value("0");

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<InequalityDecisionNodeEvaluator>(
      new InequalityDecisionNodeEvaluator(split, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example, making this six single-feature examples -- confirm against
  // test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 0);
  ASSERT_EQ(evaluator->Decide(examples, 3), 0);
  ASSERT_EQ(evaluator->Decide(examples, 4), 1);
}
45 
// Checks an InequalityDecisionNodeEvaluator configured with LESS_THAN and a
// threshold of 3.0 on feature "0": example 2 is expected to yield 0, while
// examples 3 and 4 are expected to yield 1.
TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  InequalityTest split;
  split.set_type(InequalityTest::LESS_THAN);
  split.mutable_threshold()->set_float_value(3.0);
  split.mutable_feature_id()->mutable_id()->set_value("0");

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<InequalityDecisionNodeEvaluator>(
      new InequalityDecisionNodeEvaluator(split, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example -- confirm against test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 0);
  ASSERT_EQ(evaluator->Decide(examples, 3), 1);
  ASSERT_EQ(evaluator->Decide(examples, 4), 1);
}
62 
// Checks an InequalityDecisionNodeEvaluator configured with GREATER_OR_EQUAL
// and a threshold of 3.0 on feature "0": example 2 is expected to yield 1,
// while examples 3 and 4 are expected to yield 0.
TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  InequalityTest split;
  split.set_type(InequalityTest::GREATER_OR_EQUAL);
  split.mutable_threshold()->set_float_value(3.0);
  split.mutable_feature_id()->mutable_id()->set_value("0");

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<InequalityDecisionNodeEvaluator>(
      new InequalityDecisionNodeEvaluator(split, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example -- confirm against test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 1);
  ASSERT_EQ(evaluator->Decide(examples, 3), 0);
  ASSERT_EQ(evaluator->Decide(examples, 4), 0);
}
79 
// Checks an InequalityDecisionNodeEvaluator configured with GREATER_THAN and
// a threshold of 3.0 on feature "0": examples 2 and 3 are expected to yield 1,
// example 4 is expected to yield 0.
TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  InequalityTest split;
  split.set_type(InequalityTest::GREATER_THAN);
  split.mutable_threshold()->set_float_value(3.0);
  split.mutable_feature_id()->mutable_id()->set_value("0");

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<InequalityDecisionNodeEvaluator>(
      new InequalityDecisionNodeEvaluator(split, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example -- confirm against test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 1);
  ASSERT_EQ(evaluator->Decide(examples, 3), 1);
  ASSERT_EQ(evaluator->Decide(examples, 4), 0);
}
96 
// Checks a MatchingValuesDecisionNodeEvaluator on feature "0" whose matching
// set holds the float values 3.0 and 5.0: examples 3 and 5 are expected to
// yield 0, examples 2 and 4 are expected to yield 1.
TEST(MatchingDecisionNodeEvaluatorTest, Basic) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  MatchingValuesTest matcher;
  matcher.mutable_feature_id()->mutable_id()->set_value("0");
  // The values are added in the same order as before: 3.0, then 5.0.
  matcher.add_value()->set_float_value(3.0);
  matcher.add_value()->set_float_value(5.0);

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<MatchingValuesDecisionNodeEvaluator>(
      new MatchingValuesDecisionNodeEvaluator(matcher, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example -- confirm against test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 1);
  ASSERT_EQ(evaluator->Decide(examples, 3), 0);
  ASSERT_EQ(evaluator->Decide(examples, 4), 1);
  ASSERT_EQ(evaluator->Decide(examples, 5), 0);
}
115 
// Same setup as the non-inverted matching test (values 3.0 and 5.0 on feature
// "0") but with the inverse flag set: examples 3 and 5 are now expected to
// yield 1, examples 2 and 4 are expected to yield 0.
TEST(MatchingDecisionNodeEvaluatorTest, Inverse) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  MatchingValuesTest matcher;
  matcher.set_inverse(true);
  matcher.mutable_feature_id()->mutable_id()->set_value("0");
  // The values are added in the same order as before: 3.0, then 5.0.
  matcher.add_value()->set_float_value(3.0);
  matcher.add_value()->set_float_value(5.0);

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<MatchingValuesDecisionNodeEvaluator>(
      new MatchingValuesDecisionNodeEvaluator(matcher, 0, 1));

  // NOTE(review): the second argument is presumably the number of features
  // per example -- confirm against test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));

  ASSERT_EQ(evaluator->Decide(examples, 2), 0);
  ASSERT_EQ(evaluator->Decide(examples, 3), 1);
  ASSERT_EQ(evaluator->Decide(examples, 4), 0);
  ASSERT_EQ(evaluator->Decide(examples, 5), 1);
}
135 
// Checks an ObliqueInequalityDecisionNodeEvaluator built from an oblique
// LESS_OR_EQUAL test over features "0" and "1" (weight 1.0 each) with a
// threshold of 3.0: example 0 is expected to yield 0, example 1 to yield 1.
TEST(ObliqueDecisionNodeEvaluatorTest, Basic) {
  using tensorflow::tensorforest::TensorDataSet;
  using tensorflow::tensorforest::TestableDataSet;

  InequalityTest split;
  auto* oblique = split.mutable_oblique();

  // First term: feature "0" with weight 1.0.
  auto* first_feature = oblique->add_features();
  first_feature->mutable_id()->set_value("0");
  oblique->add_weights(1.0);

  // Second term: feature "1" with weight 1.0.
  auto* second_feature = oblique->add_features();
  second_feature->mutable_id()->set_value("1");
  oblique->add_weights(1.0);

  split.set_type(InequalityTest::LESS_OR_EQUAL);
  split.mutable_threshold()->set_float_value(3.0);

  // NOTE(review): the trailing (0, 1) are presumably the left/right child ids
  // that Decide() returns -- confirm against decision_node_evaluator.h.
  auto evaluator = std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator>(
      new ObliqueInequalityDecisionNodeEvaluator(split, 0, 1));

  // NOTE(review): the second argument (2) is presumably the number of
  // features per example, i.e. rows {0,1}, {2,3}, {4,5} -- confirm against
  // test_utils.h.
  std::unique_ptr<TensorDataSet> examples(
      new TestableDataSet({0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 2));

  ASSERT_EQ(evaluator->Decide(examples, 0), 0);
  ASSERT_EQ(evaluator->Decide(examples, 1), 1);
}
158 
159 }  // namespace
160 }  // namespace tensorflow
161