1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <algorithm>
16 
17 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
18 #include "tensorflow/core/platform/macros.h"
19 
20 namespace tensorflow {
21 namespace boosted_trees {
22 namespace trees {
23 
// Sentinel leaf id returned by DecisionTree::Traverse when the requested
// sub-root id does not fall inside the tree's node list.
constexpr int kInvalidLeaf{-1};
Traverse(const DecisionTreeConfig & config,const int32 sub_root_id,const utils::Example & example)25 int DecisionTree::Traverse(const DecisionTreeConfig& config,
26                            const int32 sub_root_id,
27                            const utils::Example& example) {
28   if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
29     return kInvalidLeaf;
30   }
31   // Traverse tree starting at the provided sub-root.
32   int32 node_id = sub_root_id;
33   // The index of the leave that holds this example in the oblivious case.
34   int oblivious_leaf_idx = 0;
35   while (true) {
36     const auto& current_node = config.nodes(node_id);
37     switch (current_node.node_case()) {
38       case TreeNode::kLeaf: {
39         return node_id + oblivious_leaf_idx;
40       }
41       case TreeNode::kDenseFloatBinarySplit: {
42         const auto& split = current_node.dense_float_binary_split();
43         node_id = example.dense_float_features[split.feature_column()] <=
44                           split.threshold()
45                       ? split.left_id()
46                       : split.right_id();
47         break;
48       }
49       case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
50         const auto& split =
51             current_node.sparse_float_binary_split_default_left().split();
52         auto sparse_feature =
53             example.sparse_float_features[split.feature_column()];
54         // Feature id for the split when multivalent sparse float column, or 0
55         // by default.
56         const int32 dimension_id = split.dimension_id();
57 
58         node_id = !sparse_feature[dimension_id].has_value() ||
59                           sparse_feature[dimension_id].get_value() <=
60                               split.threshold()
61                       ? split.left_id()
62                       : split.right_id();
63         break;
64       }
65       case TreeNode::kSparseFloatBinarySplitDefaultRight: {
66         const auto& split =
67             current_node.sparse_float_binary_split_default_right().split();
68         auto sparse_feature =
69             example.sparse_float_features[split.feature_column()];
70         // Feature id for the split when multivalent sparse float column, or 0
71         // by default.
72         const int32 dimension_id = split.dimension_id();
73         node_id = sparse_feature[dimension_id].has_value() &&
74                           sparse_feature[dimension_id].get_value() <=
75                               split.threshold()
76                       ? split.left_id()
77                       : split.right_id();
78         break;
79       }
80       case TreeNode::kCategoricalIdBinarySplit: {
81         const auto& split = current_node.categorical_id_binary_split();
82         const auto& features =
83             example.sparse_int_features[split.feature_column()];
84         node_id = (std::find(features.begin(), features.end(),
85                              split.feature_id()) == features.end())
86                       ? split.right_id()
87                       : split.left_id();
88         break;
89       }
90       case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
91         const auto& split =
92             current_node.categorical_id_set_membership_binary_split();
93         // The new node_id = left_id if a feature is found, or right_id.
94         node_id = split.right_id();
95         for (const int64 feature_id :
96              example.sparse_int_features[split.feature_column()]) {
97           if (std::binary_search(split.feature_ids().begin(),
98                                  split.feature_ids().end(), feature_id)) {
99             node_id = split.left_id();
100             break;
101           }
102         }
103         break;
104       }
105       case TreeNode::kObliviousDenseFloatBinarySplit: {
106         const auto& split = current_node.oblivious_dense_float_binary_split();
107         oblivious_leaf_idx <<= 1;
108         if (example.dense_float_features[split.feature_column()] >
109             split.threshold()) {
110           oblivious_leaf_idx++;
111         }
112         node_id++;
113         break;
114       }
115       case TreeNode::kObliviousCategoricalIdBinarySplit: {
116         const auto& split =
117             current_node.oblivious_categorical_id_binary_split();
118         oblivious_leaf_idx <<= 1;
119         const auto& features =
120             example.sparse_int_features[split.feature_column()];
121         if (std::find(features.begin(), features.end(), split.feature_id()) ==
122             features.end()) {
123           oblivious_leaf_idx++;
124         }
125         node_id++;
126         break;
127       }
128       case TreeNode::NODE_NOT_SET: {
129         LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
130         break;
131       }
132     }
133     DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
134                           << current_node.DebugString();
135   }
136 }
137 
LinkChildren(const std::vector<int32> & children,TreeNode * parent_node)138 void DecisionTree::LinkChildren(const std::vector<int32>& children,
139                                 TreeNode* parent_node) {
140   // Decide how to link children depending on the parent node's type.
141   auto children_it = children.begin();
142   switch (parent_node->node_case()) {
143     case TreeNode::kLeaf: {
144       // Essentially no-op.
145       QCHECK(children.empty()) << "A leaf node cannot have children.";
146       break;
147     }
148     case TreeNode::kDenseFloatBinarySplit: {
149       QCHECK(children.size() == 2)
150           << "A binary split node must have exactly two children.";
151       auto* split = parent_node->mutable_dense_float_binary_split();
152       split->set_left_id(*children_it);
153       split->set_right_id(*++children_it);
154       break;
155     }
156     case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
157       QCHECK(children.size() == 2)
158           << "A binary split node must have exactly two children.";
159       auto* split =
160           parent_node->mutable_sparse_float_binary_split_default_left()
161               ->mutable_split();
162       split->set_left_id(*children_it);
163       split->set_right_id(*++children_it);
164       break;
165     }
166     case TreeNode::kSparseFloatBinarySplitDefaultRight: {
167       QCHECK(children.size() == 2)
168           << "A binary split node must have exactly two children.";
169       auto* split =
170           parent_node->mutable_sparse_float_binary_split_default_right()
171               ->mutable_split();
172       split->set_left_id(*children_it);
173       split->set_right_id(*++children_it);
174       break;
175     }
176     case TreeNode::kCategoricalIdBinarySplit: {
177       QCHECK(children.size() == 2)
178           << "A binary split node must have exactly two children.";
179       auto* split = parent_node->mutable_categorical_id_binary_split();
180       split->set_left_id(*children_it);
181       split->set_right_id(*++children_it);
182       break;
183     }
184     case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
185       QCHECK(children.size() == 2)
186           << "A binary split node must have exactly two children.";
187       auto* split =
188           parent_node->mutable_categorical_id_set_membership_binary_split();
189       split->set_left_id(*children_it);
190       split->set_right_id(*++children_it);
191       break;
192     }
193     case TreeNode::kObliviousDenseFloatBinarySplit: {
194       LOG(QFATAL)
195           << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
196       break;
197     }
198     case TreeNode::kObliviousCategoricalIdBinarySplit: {
199       LOG(QFATAL)
200           << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
201       break;
202     }
203     case TreeNode::NODE_NOT_SET: {
204       LOG(QFATAL) << "A non-set node cannot have children.";
205       break;
206     }
207   }
208 }
209 
GetChildren(const TreeNode & node)210 std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
211   // A node's children depend on its type.
212   switch (node.node_case()) {
213     case TreeNode::kLeaf: {
214       return {};
215     }
216     case TreeNode::kDenseFloatBinarySplit: {
217       const auto& split = node.dense_float_binary_split();
218       return {split.left_id(), split.right_id()};
219     }
220     case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
221       const auto& split = node.sparse_float_binary_split_default_left().split();
222       return {split.left_id(), split.right_id()};
223     }
224     case TreeNode::kSparseFloatBinarySplitDefaultRight: {
225       const auto& split =
226           node.sparse_float_binary_split_default_right().split();
227       return {split.left_id(), split.right_id()};
228     }
229     case TreeNode::kCategoricalIdBinarySplit: {
230       const auto& split = node.categorical_id_binary_split();
231       return {split.left_id(), split.right_id()};
232     }
233     case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
234       const auto& split = node.categorical_id_set_membership_binary_split();
235       return {split.left_id(), split.right_id()};
236     }
237     case TreeNode::kObliviousDenseFloatBinarySplit: {
238       LOG(QFATAL)
239           << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
240       return {};
241     }
242     case TreeNode::kObliviousCategoricalIdBinarySplit: {
243       LOG(QFATAL)
244           << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
245       break;
246     }
247     case TreeNode::NODE_NOT_SET: {
248       return {};
249     }
250   }
251 }
252 
253 }  // namespace trees
254 }  // namespace boosted_trees
255 }  // namespace tensorflow
256