1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <algorithm>
16
17 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
18 #include "tensorflow/core/platform/macros.h"
19
20 namespace tensorflow {
21 namespace boosted_trees {
22 namespace trees {
23
24 constexpr int kInvalidLeaf = -1;
Traverse(const DecisionTreeConfig & config,const int32 sub_root_id,const utils::Example & example)25 int DecisionTree::Traverse(const DecisionTreeConfig& config,
26 const int32 sub_root_id,
27 const utils::Example& example) {
28 if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
29 return kInvalidLeaf;
30 }
31 // Traverse tree starting at the provided sub-root.
32 int32 node_id = sub_root_id;
33 // The index of the leave that holds this example in the oblivious case.
34 int oblivious_leaf_idx = 0;
35 while (true) {
36 const auto& current_node = config.nodes(node_id);
37 switch (current_node.node_case()) {
38 case TreeNode::kLeaf: {
39 return node_id + oblivious_leaf_idx;
40 }
41 case TreeNode::kDenseFloatBinarySplit: {
42 const auto& split = current_node.dense_float_binary_split();
43 node_id = example.dense_float_features[split.feature_column()] <=
44 split.threshold()
45 ? split.left_id()
46 : split.right_id();
47 break;
48 }
49 case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
50 const auto& split =
51 current_node.sparse_float_binary_split_default_left().split();
52 auto sparse_feature =
53 example.sparse_float_features[split.feature_column()];
54 // Feature id for the split when multivalent sparse float column, or 0
55 // by default.
56 const int32 dimension_id = split.dimension_id();
57
58 node_id = !sparse_feature[dimension_id].has_value() ||
59 sparse_feature[dimension_id].get_value() <=
60 split.threshold()
61 ? split.left_id()
62 : split.right_id();
63 break;
64 }
65 case TreeNode::kSparseFloatBinarySplitDefaultRight: {
66 const auto& split =
67 current_node.sparse_float_binary_split_default_right().split();
68 auto sparse_feature =
69 example.sparse_float_features[split.feature_column()];
70 // Feature id for the split when multivalent sparse float column, or 0
71 // by default.
72 const int32 dimension_id = split.dimension_id();
73 node_id = sparse_feature[dimension_id].has_value() &&
74 sparse_feature[dimension_id].get_value() <=
75 split.threshold()
76 ? split.left_id()
77 : split.right_id();
78 break;
79 }
80 case TreeNode::kCategoricalIdBinarySplit: {
81 const auto& split = current_node.categorical_id_binary_split();
82 const auto& features =
83 example.sparse_int_features[split.feature_column()];
84 node_id = (std::find(features.begin(), features.end(),
85 split.feature_id()) == features.end())
86 ? split.right_id()
87 : split.left_id();
88 break;
89 }
90 case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
91 const auto& split =
92 current_node.categorical_id_set_membership_binary_split();
93 // The new node_id = left_id if a feature is found, or right_id.
94 node_id = split.right_id();
95 for (const int64 feature_id :
96 example.sparse_int_features[split.feature_column()]) {
97 if (std::binary_search(split.feature_ids().begin(),
98 split.feature_ids().end(), feature_id)) {
99 node_id = split.left_id();
100 break;
101 }
102 }
103 break;
104 }
105 case TreeNode::kObliviousDenseFloatBinarySplit: {
106 const auto& split = current_node.oblivious_dense_float_binary_split();
107 oblivious_leaf_idx <<= 1;
108 if (example.dense_float_features[split.feature_column()] >
109 split.threshold()) {
110 oblivious_leaf_idx++;
111 }
112 node_id++;
113 break;
114 }
115 case TreeNode::kObliviousCategoricalIdBinarySplit: {
116 const auto& split =
117 current_node.oblivious_categorical_id_binary_split();
118 oblivious_leaf_idx <<= 1;
119 const auto& features =
120 example.sparse_int_features[split.feature_column()];
121 if (std::find(features.begin(), features.end(), split.feature_id()) ==
122 features.end()) {
123 oblivious_leaf_idx++;
124 }
125 node_id++;
126 break;
127 }
128 case TreeNode::NODE_NOT_SET: {
129 LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
130 break;
131 }
132 }
133 DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
134 << current_node.DebugString();
135 }
136 }
137
LinkChildren(const std::vector<int32> & children,TreeNode * parent_node)138 void DecisionTree::LinkChildren(const std::vector<int32>& children,
139 TreeNode* parent_node) {
140 // Decide how to link children depending on the parent node's type.
141 auto children_it = children.begin();
142 switch (parent_node->node_case()) {
143 case TreeNode::kLeaf: {
144 // Essentially no-op.
145 QCHECK(children.empty()) << "A leaf node cannot have children.";
146 break;
147 }
148 case TreeNode::kDenseFloatBinarySplit: {
149 QCHECK(children.size() == 2)
150 << "A binary split node must have exactly two children.";
151 auto* split = parent_node->mutable_dense_float_binary_split();
152 split->set_left_id(*children_it);
153 split->set_right_id(*++children_it);
154 break;
155 }
156 case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
157 QCHECK(children.size() == 2)
158 << "A binary split node must have exactly two children.";
159 auto* split =
160 parent_node->mutable_sparse_float_binary_split_default_left()
161 ->mutable_split();
162 split->set_left_id(*children_it);
163 split->set_right_id(*++children_it);
164 break;
165 }
166 case TreeNode::kSparseFloatBinarySplitDefaultRight: {
167 QCHECK(children.size() == 2)
168 << "A binary split node must have exactly two children.";
169 auto* split =
170 parent_node->mutable_sparse_float_binary_split_default_right()
171 ->mutable_split();
172 split->set_left_id(*children_it);
173 split->set_right_id(*++children_it);
174 break;
175 }
176 case TreeNode::kCategoricalIdBinarySplit: {
177 QCHECK(children.size() == 2)
178 << "A binary split node must have exactly two children.";
179 auto* split = parent_node->mutable_categorical_id_binary_split();
180 split->set_left_id(*children_it);
181 split->set_right_id(*++children_it);
182 break;
183 }
184 case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
185 QCHECK(children.size() == 2)
186 << "A binary split node must have exactly two children.";
187 auto* split =
188 parent_node->mutable_categorical_id_set_membership_binary_split();
189 split->set_left_id(*children_it);
190 split->set_right_id(*++children_it);
191 break;
192 }
193 case TreeNode::kObliviousDenseFloatBinarySplit: {
194 LOG(QFATAL)
195 << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
196 break;
197 }
198 case TreeNode::kObliviousCategoricalIdBinarySplit: {
199 LOG(QFATAL)
200 << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
201 break;
202 }
203 case TreeNode::NODE_NOT_SET: {
204 LOG(QFATAL) << "A non-set node cannot have children.";
205 break;
206 }
207 }
208 }
209
GetChildren(const TreeNode & node)210 std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
211 // A node's children depend on its type.
212 switch (node.node_case()) {
213 case TreeNode::kLeaf: {
214 return {};
215 }
216 case TreeNode::kDenseFloatBinarySplit: {
217 const auto& split = node.dense_float_binary_split();
218 return {split.left_id(), split.right_id()};
219 }
220 case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
221 const auto& split = node.sparse_float_binary_split_default_left().split();
222 return {split.left_id(), split.right_id()};
223 }
224 case TreeNode::kSparseFloatBinarySplitDefaultRight: {
225 const auto& split =
226 node.sparse_float_binary_split_default_right().split();
227 return {split.left_id(), split.right_id()};
228 }
229 case TreeNode::kCategoricalIdBinarySplit: {
230 const auto& split = node.categorical_id_binary_split();
231 return {split.left_id(), split.right_id()};
232 }
233 case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
234 const auto& split = node.categorical_id_set_membership_binary_split();
235 return {split.left_id(), split.right_id()};
236 }
237 case TreeNode::kObliviousDenseFloatBinarySplit: {
238 LOG(QFATAL)
239 << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
240 return {};
241 }
242 case TreeNode::kObliviousCategoricalIdBinarySplit: {
243 LOG(QFATAL)
244 << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
245 break;
246 }
247 case TreeNode::NODE_NOT_SET: {
248 return {};
249 }
250 }
251 }
252
253 } // namespace trees
254 } // namespace boosted_trees
255 } // namespace tensorflow
256