// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"

#include <algorithm>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace boosted_trees {
namespace testutil {

using boosted_trees::trees::DenseFloatBinarySplit;
using tensorflow::boosted_trees::trees::DecisionTreeConfig;
using tensorflow::boosted_trees::trees::TreeNode;

namespace {

// Appends the given nodes to the tree, transferring pointer ownership.
// `nodes` will not be usable upon return.
template <typename T>
void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
  std::reverse(nodes->pointer_begin(), nodes->pointer_end());
  while (!nodes->empty()) {
    tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
  }
}

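// Returns a mutable pointer to the DenseFloatBinarySplit payload shared by all
// split node variants; fails fatally if the node is not a split.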
DenseFloatBinarySplit* GetSplit(TreeNode* node) {
  switch (node->node_case()) {
    case TreeNode::kSparseFloatBinarySplitDefaultLeft:
      return node->mutable_sparse_float_binary_split_default_left()
          ->mutable_split();
    case TreeNode::kSparseFloatBinarySplitDefaultRight:
      return node->mutable_sparse_float_binary_split_default_right()
          ->mutable_split();
    case TreeNode::kDenseFloatBinarySplit:
      return node->mutable_dense_float_binary_split();
    default:
      LOG(FATAL) << "Unknown node type encountered.";
  }
  return nullptr;
}

}  // namespace

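// A minimal usage sketch (illustrative only; the seed and feature sizes below
// are arbitrary):
//
//   tensorflow::random::PhiloxRandom philox(13);
//   tensorflow::random::SimplePhilox rng(&philox);
//   RandomTreeGen tree_gen(&rng, /*dense_feature_size=*/10,
//                          /*sparse_feature_size=*/5);
//   auto ensemble = tree_gen.GenerateEnsemble(/*depth=*/3, /*tree_count=*/2);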
RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
                             int dense_feature_size, int sparse_feature_size)
    : rng_(rng),
      dense_feature_size_(dense_feature_size),
      sparse_feature_size_(sparse_feature_size) {}

namespace {
void AddWeightAndMetadata(
    boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
  // Set the tree's weight to 1.0 and record that this weight was updated
  // only once.
  ret->add_tree_weights(1.0);
  auto* meta = ret->add_tree_metadata();
  meta->set_num_tree_weight_updates(1);
}

}  // namespace

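// Builds an ensemble of `tree_count` trees of the given depth. The first tree
// is generated from scratch; every subsequent tree reuses its topology with
// freshly randomized split types, thresholds, and leaf values. Each tree is
// added with a weight of 1.0.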
boosted_trees::trees::DecisionTreeEnsembleConfig
RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
  boosted_trees::trees::DecisionTreeEnsembleConfig ret;
  *(ret.add_trees()) = Generate(depth);
  AddWeightAndMetadata(&ret);
  for (int i = 1; i < tree_count; ++i) {
    *(ret.add_trees()) = Generate(ret.trees(0));
    AddWeightAndMetadata(&ret);
  }
  return ret;
}

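// Returns a copy of `tree` in which every leaf value is redrawn uniformly and
// every split is regenerated (its feature type may change) while the original
// left/right child indices are preserved.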
DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
  DecisionTreeConfig ret = tree;
  for (auto& node : *ret.mutable_nodes()) {
    if (node.node_case() == TreeNode::kLeaf) {
      node.mutable_leaf()->mutable_sparse_vector()->set_value(
          0, rng_->RandFloat());
      continue;
    }
    // The original node is a split. Re-generate its type but retain the
    // split's child node indices.
    DenseFloatBinarySplit* split = GetSplit(&node);
    const int left_id = split->left_id();
    const int right_id = split->right_id();
    GenerateSplit(&node, left_id, right_id);
  }
  return ret;
}

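// Recursively builds a random tree of the requested depth: the root receives a
// random split, and for depth > 1 two independently generated subtrees of
// depth - 1 are attached beneath it.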
DecisionTreeConfig RandomTreeGen::Generate(int depth) {
  DecisionTreeConfig ret;
  // Add the root.
  TreeNode* node = ret.add_nodes();
  GenerateSplit(node, 1, 2);
  if (depth == 1) {
    // Add left and right leaves.
    TreeNode* left = ret.add_nodes();
    left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
    left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
    TreeNode* right = ret.add_nodes();
    right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
    right->mutable_leaf()->mutable_sparse_vector()->add_value(
        rng_->RandFloat());
    return ret;
  } else {
    DecisionTreeConfig left_branch = Generate(depth - 1);
    DecisionTreeConfig right_branch = Generate(depth - 1);
    Combine(&ret, &left_branch, &right_branch);
    return ret;
  }
}

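// Attaches `left_branch` and `right_branch` under the single split node in
// `root`, shifting the indices inside each branch so that all child pointers
// in the merged tree stay consistent.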
void RandomTreeGen::Combine(DecisionTreeConfig* root,
                            DecisionTreeConfig* left_branch,
                            DecisionTreeConfig* right_branch) {
  const int left_branch_size = left_branch->nodes_size();
  CHECK_EQ(1, root->nodes_size());
  // left_branch starts its index at 1. right_branch starts its index at
  // (left_branch_size + 1).
  auto* root_node = root->mutable_nodes(0);
  DenseFloatBinarySplit* root_split = GetSplit(root_node);
  root_split->set_left_id(1);
  root_split->set_right_id(left_branch_size + 1);
  // Shift the left/right branch's indices internally so that everything is
  // consistent.
  ShiftNodeIndex(left_branch, 1);
  ShiftNodeIndex(right_branch, left_branch_size + 1);

  // Complexity is O(branch node size), with no proto copying.
  AppendNodes(root, left_branch->mutable_nodes());
  AppendNodes(root, right_branch->mutable_nodes());
}

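// Adds `shift` to the left/right child indices of every split node in `tree`;
// leaf nodes are left unchanged.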
void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
  for (TreeNode& node : *(tree->mutable_nodes())) {
    DenseFloatBinarySplit* split = nullptr;
    switch (node.node_case()) {
      case TreeNode::kLeaf:
        break;
      case TreeNode::kSparseFloatBinarySplitDefaultLeft:
        split = node.mutable_sparse_float_binary_split_default_left()
                    ->mutable_split();
        break;
      case TreeNode::kSparseFloatBinarySplitDefaultRight:
        split = node.mutable_sparse_float_binary_split_default_right()
                    ->mutable_split();
        break;
      case TreeNode::kDenseFloatBinarySplit:
        split = node.mutable_dense_float_binary_split();
        break;
      default:
        LOG(FATAL) << "Unknown node type encountered.";
    }
    if (split) {
      split->set_left_id(shift + split->left_id());
      split->set_right_id(shift + split->right_id());
    }
  }
}

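// Fills `node` with a random split. Dense vs. sparse is chosen with probability
// proportional to the respective feature counts, the sparse default direction
// is chosen uniformly, and the threshold and feature column are drawn at
// random.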
void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
  const double dense_split_prob =
      sparse_feature_size_ == 0
          ? 1.0
          : static_cast<double>(dense_feature_size_) /
                (dense_feature_size_ + sparse_feature_size_);
  // Generate the tree such that it has equal probability of going left and
  // right when the feature is missing.
  static constexpr float kLeftProb = 0.5;

  DenseFloatBinarySplit* split;
  int feature_size;
  if (rng_->RandFloat() < dense_split_prob) {
    feature_size = dense_feature_size_;
    split = node->mutable_dense_float_binary_split();
  } else {
    feature_size = sparse_feature_size_;
    if (rng_->RandFloat() < kLeftProb) {
      split = node->mutable_sparse_float_binary_split_default_left()
                  ->mutable_split();
    } else {
      split = node->mutable_sparse_float_binary_split_default_right()
                  ->mutable_split();
    }
  }
  split->set_threshold(rng_->RandFloat());
  split->set_feature_column(rng_->Uniform(feature_size));
  split->set_left_id(left_id);
  split->set_right_id(right_id);
}

}  // namespace testutil
}  // namespace boosted_trees
}  // namespace tensorflow