/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"

namespace tensorflow {

namespace {
constexpr float kLayerByLayerTreeWeight = 1.0;
constexpr float kMinDeltaForCenterBias = 0.01;

// TODO(nponomareva, youngheek): consider using vector.
struct SplitCandidate {
  SplitCandidate() {}

  // Index in the list of the feature ids.
  int64 feature_idx;

  // Index in the tensor of node_ids for the feature with idx feature_idx.
  int64 candidate_idx;

  float gain;
};

enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };

}  // namespace

class BoostedTreesUpdateEnsembleOp : public OpKernel {
 public:
  explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));

    int32 pruning_index;
    OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index));
    pruning_mode_ = static_cast<PruningMode>(pruning_index);
  }

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    BoostedTreesEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
    // Increase the ensemble stamp.
    ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

    // Read node ids, gains, thresholds and node contribs.
    OpInputList node_ids_list;
    OpInputList gains_list;
    OpInputList thresholds_list;
    OpInputList left_node_contribs;
    OpInputList right_node_contribs;
    OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
    OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
    OP_REQUIRES_OK(context,
                   context->input_list("thresholds", &thresholds_list));
    OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
                                                &left_node_contribs));
    OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
                                                &right_node_contribs));

    const Tensor* feature_ids_t;
    OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    const auto feature_ids = feature_ids_t->vec<int32>();

    const Tensor* max_depth_t;
    OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
    const auto max_depth = max_depth_t->scalar<int32>()();

    const Tensor* learning_rate_t;
    OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
    const auto learning_rate = learning_rate_t->scalar<float>()();

    // Find best splits for each active node.
    std::map<int32, SplitCandidate> best_splits;
    FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
                          &best_splits);

    int32 current_tree =
        UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);

    // No-op if no new splits can be considered.
    if (best_splits.empty()) {
      LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
      return;
    }

    const int32 new_num_layers =
        ensemble_resource->GetNumLayersGrown(current_tree) + 1;
    VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
            << current_tree << " of ensemble of " << current_tree + 1
            << " trees.";
    bool split_happened = false;
    int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
    // Add the splits to the tree.
    for (auto& split_entry : best_splits) {
      const int32 node_id = split_entry.first;
      const SplitCandidate& candidate = split_entry.second;

      const int64 feature_idx = candidate.feature_idx;
      const int64 candidate_idx = candidate.candidate_idx;

      const int32 feature_id = feature_ids(feature_idx);
      const int32 threshold =
          thresholds_list[feature_idx].vec<int32>()(candidate_idx);
      const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);

      if (pruning_mode_ == kPrePruning) {
        // Don't consider negative splits if we're pre-pruning the tree.
        // Note that zero-gain splits are acceptable.
        if (gain < 0) {
          continue;
        }
      }
      // For now, assume that the weight vectors are one-dimensional.
      // TODO(nponomareva): change here for multiclass.
      const float left_contrib =
          learning_rate *
          left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
      const float right_contrib =
          learning_rate *
          right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);

      // The resulting left and right node ids are unused here.
      int32 left_node_id;
      int32 right_node_id;

      ensemble_resource->AddBucketizedSplitNode(
          current_tree, node_id, feature_id, threshold, gain, left_contrib,
          right_contrib, &left_node_id, &right_node_id);
      split_happened = true;
    }
    int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
    if (split_happened) {
      // Update growable tree metadata.
      ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
      // Finalize the tree if needed.
      if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) {
        // If the tree is finalized, the next growing will start from node 0.
        node_id_start = 0;
        node_id_end = 1;
        ensemble_resource->SetIsFinalized(current_tree, true);
        if (pruning_mode_ == kPostPruning) {
          ensemble_resource->PostPruneTree(current_tree);
        }
        if (ensemble_resource->num_trees() > 0) {
          // Create a dummy new tree with an empty node.
          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
        }
      }
      // If we managed to split, update the node range. If we didn't, don't
      // update it, as we will try to split the same nodes with new instances.
      ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end);
    }
  }

 private:
  int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
      BoostedTreesEnsembleResource* const ensemble_resource) {
    int32 num_trees = ensemble_resource->num_trees();
    int32 current_tree = num_trees - 1;

    // Increment global attempt stats.
    ensemble_resource->UpdateGrowingMetadata();

    // Note that we don't set the tree weight equal to the learning rate, since
    // we apply the learning rate to the leaf weights instead when doing
    // layer-by-layer boosting.
    if (num_trees <= 0) {
      // Create a new tree with a no-op leaf.
      current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
    }
    return current_tree;
  }

  // Helper method which effectively does a reduce over all split candidates
  // and finds the best split for each node.
  void FindBestSplitsPerNode(
      OpKernelContext* const context, const OpInputList& node_ids_list,
      const OpInputList& gains_list,
      const TTypes<const int32>::Vec& feature_ids,
      std::map<int32, SplitCandidate>* best_split_per_node) {
    // Find the best split per node by going through every feature candidate.
    for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
      const auto& gains = gains_list[feature_idx].vec<float>();

      for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
           ++candidate_idx) {
        // Get the current split candidate.
        const auto& node_id = node_ids(candidate_idx);
        const auto& gain = gains(candidate_idx);

        auto best_split_it = best_split_per_node->find(node_id);
        SplitCandidate candidate;
        candidate.feature_idx = feature_idx;
        candidate.candidate_idx = candidate_idx;
        candidate.gain = gain;

        if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
                             GainsAreEqual(gain, best_split_it->second.gain))) {
          const auto best_candidate = (*best_split_per_node)[node_id];
          const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
          const int32 feature_id = feature_ids(candidate.feature_idx);
          VLOG(2) << "Breaking ties on feature ids and buckets";
          // Break ties deterministically by preferring the smaller feature id.
          if (feature_id < best_feature_id) {
            (*best_split_per_node)[node_id] = candidate;
          }
        } else if (best_split_it == best_split_per_node->end() ||
                   GainIsLarger(gain, best_split_it->second.gain)) {
          (*best_split_per_node)[node_id] = candidate;
        }
      }
    }
  }

 private:
  int32 num_features_;
  PruningMode pruning_mode_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
                        BoostedTreesUpdateEnsembleOp);

class BoostedTreesCenterBiasOp : public OpKernel {
 public:
  explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    BoostedTreesEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
    // Increase the ensemble stamp.
    ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

    // Read means of hessians and gradients.
    const Tensor* mean_gradients_t;
    OP_REQUIRES_OK(context,
                   context->input("mean_gradients", &mean_gradients_t));

    const Tensor* mean_hessians_t;
    OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));

    // Get the regularization options.
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();

    // For now, assume 1-dimensional weight on leaves.
    float logits;
    float unused_gain;

    // TODO(nponomareva): change this when supporting multiclass.
    const float gradients_mean = mean_gradients_t->flat<float>()(0);
    const float hessians_mean = mean_hessians_t->flat<float>()(0);
    CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits,
                             &unused_gain);

    float current_bias = 0.0;
    bool continue_centering = true;
    if (ensemble_resource->num_trees() == 0) {
      ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
      current_bias = logits;
    } else {
      const auto& current_biases = ensemble_resource->node_value(0, 0);
      DCHECK_EQ(current_biases.size(), 1);
      current_bias = current_biases[0];
      continue_centering =
          std::abs(logits / current_bias) > kMinDeltaForCenterBias;
      current_bias += logits;
      ensemble_resource->set_node_value(0, 0, current_bias);
    }

    Tensor* continue_centering_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output("continue_centering", TensorShape({}),
                                          &continue_centering_t));
    // Check if we need to continue centering bias.
    continue_centering_t->scalar<bool>()() = continue_centering;
  }
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
                        BoostedTreesCenterBiasOp);

}  // namespace tensorflow