• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
16 
17 #include "tensorflow/core/lib/random/philox_random.h"
18 #include "tensorflow/core/lib/random/simple_philox.h"
19 #include "tensorflow/core/platform/logging.h"
20 
21 namespace tensorflow {
22 namespace boosted_trees {
23 namespace testutil {
24 
25 using boosted_trees::trees::DenseFloatBinarySplit;
26 using tensorflow::boosted_trees::trees::DecisionTreeConfig;
27 using tensorflow::boosted_trees::trees::TreeNode;
28 
29 namespace {
30 
31 // Append the given nodes to tree with transfer of pointer ownership.
32 // nodes will not be usable upon return.
33 template <typename T>
AppendNodes(DecisionTreeConfig * tree,T * nodes)34 void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
35   std::reverse(nodes->pointer_begin(), nodes->pointer_end());
36   while (!nodes->empty()) {
37     tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
38   }
39 }
40 
GetSplit(TreeNode * node)41 DenseFloatBinarySplit* GetSplit(TreeNode* node) {
42   switch (node->node_case()) {
43     case TreeNode::kSparseFloatBinarySplitDefaultLeft:
44       return node->mutable_sparse_float_binary_split_default_left()
45           ->mutable_split();
46     case TreeNode::kSparseFloatBinarySplitDefaultRight:
47       return node->mutable_sparse_float_binary_split_default_right()
48           ->mutable_split();
49     case TreeNode::kDenseFloatBinarySplit:
50       return node->mutable_dense_float_binary_split();
51     default:
52       LOG(FATAL) << "Unknown node type encountered.";
53   }
54   return nullptr;
55 }
56 
57 }  // namespace
58 
RandomTreeGen(tensorflow::random::SimplePhilox * rng,int dense_feature_size,int sparse_feature_size)59 RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
60                              int dense_feature_size, int sparse_feature_size)
61     : rng_(rng),
62       dense_feature_size_(dense_feature_size),
63       sparse_feature_size_(sparse_feature_size) {}
64 
65 namespace {
AddWeightAndMetadata(boosted_trees::trees::DecisionTreeEnsembleConfig * ret)66 void AddWeightAndMetadata(
67     boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
68   // Assign the weight of the tree to 1 and say that this weight was updated
69   // only once.
70   ret->add_tree_weights(1.0);
71   auto* meta = ret->add_tree_metadata();
72   meta->set_num_tree_weight_updates(1);
73 }
74 
75 }  //  namespace
76 
77 boosted_trees::trees::DecisionTreeEnsembleConfig
GenerateEnsemble(int depth,int tree_count)78 RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
79   boosted_trees::trees::DecisionTreeEnsembleConfig ret;
80   *(ret.add_trees()) = Generate(depth);
81   AddWeightAndMetadata(&ret);
82   for (int i = 1; i < tree_count; ++i) {
83     *(ret.add_trees()) = Generate(ret.trees(0));
84     AddWeightAndMetadata(&ret);
85   }
86   return ret;
87 }
88 
// Returns a copy of `tree` with the same node layout but fresh random content:
// each leaf has the first value of its sparse vector replaced by a new random
// float, and each split node is regenerated (possibly as a different split
// type) while keeping its original left/right child ids.
//
// Leaves are expected to already hold at least one sparse value, as the trees
// built by Generate(int) do.
DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
  DecisionTreeConfig ret = tree;
  const int node_count = ret.nodes_size();
  for (int i = 0; i < node_count; ++i) {
    TreeNode* current = ret.mutable_nodes(i);
    if (current->node_case() != TreeNode::kLeaf) {
      // The node is a split. Read its child ids before regenerating it, since
      // GenerateSplit may switch the node to another split type.
      const DenseFloatBinarySplit* previous = GetSplit(current);
      const int left_id = previous->left_id();
      const int right_id = previous->right_id();
      GenerateSplit(current, left_id, right_id);
    } else {
      auto* leaf_values = current->mutable_leaf()->mutable_sparse_vector();
      leaf_values->set_value(0, rng_->RandFloat());
    }
  }
  return ret;
}
106 
Generate(int depth)107 DecisionTreeConfig RandomTreeGen::Generate(int depth) {
108   DecisionTreeConfig ret;
109   // Add root,
110   TreeNode* node = ret.add_nodes();
111   GenerateSplit(node, 1, 2);
112   if (depth == 1) {
113     // Add left and right leaves.
114     TreeNode* left = ret.add_nodes();
115     left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
116     left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
117     TreeNode* right = ret.add_nodes();
118     right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
119     right->mutable_leaf()->mutable_sparse_vector()->add_value(
120         rng_->RandFloat());
121     return ret;
122   } else {
123     DecisionTreeConfig left_branch = Generate(depth - 1);
124     DecisionTreeConfig right_branch = Generate(depth - 1);
125     Combine(&ret, &left_branch, &right_branch);
126     return ret;
127   }
128 }
129 
Combine(DecisionTreeConfig * root,DecisionTreeConfig * left_branch,DecisionTreeConfig * right_branch)130 void RandomTreeGen::Combine(DecisionTreeConfig* root,
131                             DecisionTreeConfig* left_branch,
132                             DecisionTreeConfig* right_branch) {
133   const int left_branch_size = left_branch->nodes_size();
134   CHECK_EQ(1, root->nodes_size());
135   // left_branch starts its index at 1. right_branch starts its index at
136   // (left_branch_size + 1).
137   auto* root_node = root->mutable_nodes(0);
138   DenseFloatBinarySplit* root_split = GetSplit(root_node);
139   root_split->set_left_id(1);
140   root_split->set_right_id(left_branch_size + 1);
141   // Shift left/right branch's indices internally so that everything is
142   // consistent.
143   ShiftNodeIndex(left_branch, 1);
144   ShiftNodeIndex(right_branch, left_branch_size + 1);
145 
146   // Complexity O(branch node size). No proto copying though.
147   AppendNodes(root, left_branch->mutable_nodes());
148   AppendNodes(root, right_branch->mutable_nodes());
149 }
150 
ShiftNodeIndex(DecisionTreeConfig * tree,int shift)151 void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
152   for (TreeNode& node : *(tree->mutable_nodes())) {
153     DenseFloatBinarySplit* split = nullptr;
154     switch (node.node_case()) {
155       case TreeNode::kLeaf:
156         break;
157       case TreeNode::kSparseFloatBinarySplitDefaultLeft:
158         split = node.mutable_sparse_float_binary_split_default_left()
159                     ->mutable_split();
160         break;
161       case TreeNode::kSparseFloatBinarySplitDefaultRight:
162         split = node.mutable_sparse_float_binary_split_default_right()
163                     ->mutable_split();
164         break;
165       case TreeNode::kDenseFloatBinarySplit:
166         split = node.mutable_dense_float_binary_split();
167         break;
168       default:
169         LOG(FATAL) << "Unknown node type encountered.";
170     }
171     if (split) {
172       split->set_left_id(shift + split->left_id());
173       split->set_right_id(shift + split->right_id());
174     }
175   }
176 }
177 
GenerateSplit(TreeNode * node,int left_id,int right_id)178 void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
179   const double denseSplitProb =
180       sparse_feature_size_ == 0
181           ? 1.0
182           : static_cast<double>(dense_feature_size_) /
183                 (dense_feature_size_ + sparse_feature_size_);
184   // Generate the tree such that it has equal probability of going left and
185   // right when the feature is missing.
186   static constexpr float kLeftProb = 0.5;
187 
188   DenseFloatBinarySplit* split;
189   int feature_size;
190   if (rng_->RandFloat() < denseSplitProb) {
191     feature_size = dense_feature_size_;
192     split = node->mutable_dense_float_binary_split();
193   } else {
194     feature_size = sparse_feature_size_;
195     if (rng_->RandFloat() < kLeftProb) {
196       split = node->mutable_sparse_float_binary_split_default_left()
197                   ->mutable_split();
198     } else {
199       split = node->mutable_sparse_float_binary_split_default_right()
200                   ->mutable_split();
201     }
202   }
203   split->set_threshold(rng_->RandFloat());
204   split->set_feature_column(rng_->Uniform(feature_size));
205   split->set_left_id(left_id);
206   split->set_right_id(right_id);
207 }
208 
209 }  // namespace testutil
210 }  // namespace boosted_trees
211 }  // namespace tensorflow
212