/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"

namespace tensorflow {

namespace {
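// Weight assigned to each tree when growing the ensemble layer by layer; the
// learning rate is applied to the leaf values rather than to the tree weight.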
constexpr float kLayerByLayerTreeWeight = 1.0;
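// Threshold on the relative bias update; bias centering stops once the update
// is no longer larger than this fraction of the current bias.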
constexpr float kMinDeltaForCenterBias = 0.01;

// TODO(nponomareva, youngheek): consider using vector.
struct SplitCandidate {
  SplitCandidate() {}

  // Index in the list of the feature ids.
  int64 feature_idx;

  // Index in the tensor of node_ids for the feature with idx feature_idx.
  int64 candidate_idx;

  float gain;
};

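// Pruning strategies: kPrePruning skips splits with negative gain while the
// tree is being grown, and kPostPruning prunes the tree after it is finalized.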
enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };

}  // namespace

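// Op that grows the tree ensemble by one layer, adding the best split found
// for each active node.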
class BoostedTreesUpdateEnsembleOp : public OpKernel {
 public:
  explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));

    int32 pruning_index;
    OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index));
    pruning_mode_ = static_cast<PruningMode>(pruning_index);
  }

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    BoostedTreesEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
    // Increase the ensemble stamp.
    ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

    // Read node ids, gains, thresholds and node contribs.
    OpInputList node_ids_list;
    OpInputList gains_list;
    OpInputList thresholds_list;
    OpInputList left_node_contribs;
    OpInputList right_node_contribs;
    OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
    OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
    OP_REQUIRES_OK(context,
                   context->input_list("thresholds", &thresholds_list));
    OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
                                                &left_node_contribs));
    OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
                                                &right_node_contribs));

    const Tensor* feature_ids_t;
    OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    const auto feature_ids = feature_ids_t->vec<int32>();

    const Tensor* max_depth_t;
    OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
    const auto max_depth = max_depth_t->scalar<int32>()();

    const Tensor* learning_rate_t;
    OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
    const auto learning_rate = learning_rate_t->scalar<float>()();

    // Find best splits for each active node.
    std::map<int32, SplitCandidate> best_splits;
    FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
                          &best_splits);

    int32 current_tree =
        UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);

    // No-op if no new splits can be considered.
    if (best_splits.empty()) {
      LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
      return;
    }

    const int32 new_num_layers =
        ensemble_resource->GetNumLayersGrown(current_tree) + 1;
    VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
            << current_tree << " of ensemble of " << current_tree + 1
            << " trees.";
    bool split_happened = false;
    int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
    // Add the splits to the tree.
    for (auto& split_entry : best_splits) {
      const int32 node_id = split_entry.first;
      const SplitCandidate& candidate = split_entry.second;

      const int64 feature_idx = candidate.feature_idx;
      const int64 candidate_idx = candidate.candidate_idx;

      const int32 feature_id = feature_ids(feature_idx);
      const int32 threshold =
          thresholds_list[feature_idx].vec<int32>()(candidate_idx);
      const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);

      if (pruning_mode_ == kPrePruning) {
        // Don't consider negative-gain splits if we're pre-pruning the tree.
        // Note that zero-gain splits are acceptable.
        if (gain < 0) {
          continue;
        }
      }
      // For now assume that the weight vectors are one-dimensional.
      // TODO(nponomareva): change here for multiclass.
      const float left_contrib =
          learning_rate *
          left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
      const float right_contrib =
          learning_rate *
          right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);

      // Outputs of AddBucketizedSplitNode; unused here.
      int32 left_node_id;
      int32 right_node_id;

      ensemble_resource->AddBucketizedSplitNode(
          current_tree, node_id, feature_id, threshold, gain, left_contrib,
          right_contrib, &left_node_id, &right_node_id);
      split_happened = true;
    }
    int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
    if (split_happened) {
      // Update growable tree metadata.
      ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
      // Finalize the tree if needed.
      if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) {
        // If the tree is finalized, next growing will start from node 0.
        node_id_start = 0;
        node_id_end = 1;
        ensemble_resource->SetIsFinalized(current_tree, true);
        if (pruning_mode_ == kPostPruning) {
          ensemble_resource->PostPruneTree(current_tree);
        }
        if (ensemble_resource->num_trees() > 0) {
          // Create a dummy new tree with an empty node.
          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
        }
      }
      // If we managed to split, update the node range. If we didn't, don't
      // update as we will try to split the same nodes with new instances.
      ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end);
    }
  }

 private:
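  // Bumps the global growing-attempt counter and returns the index of the tree
  // to grow, creating the first tree if the ensemble is still empty.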
  int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
      BoostedTreesEnsembleResource* const ensemble_resource) {
    int32 num_trees = ensemble_resource->num_trees();
    int32 current_tree = num_trees - 1;

    // Increment global attempt stats.
    ensemble_resource->UpdateGrowingMetadata();

    // Note that we don't set the tree weight equal to the learning rate, since
    // we apply the learning rate to the leaf weights instead when doing
    // layer-by-layer boosting.
    if (num_trees <= 0) {
      // Create a new tree with a no-op leaf.
      current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
    }
    return current_tree;
  }

  // Helper method which effectively does a reduce over all split candidates
  // and finds the best split for each node.
  void FindBestSplitsPerNode(
      OpKernelContext* const context, const OpInputList& node_ids_list,
      const OpInputList& gains_list,
      const TTypes<const int32>::Vec& feature_ids,
      std::map<int32, SplitCandidate>* best_split_per_node) {
    // Find best split per node going through every feature candidate.
    for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
      const auto& gains = gains_list[feature_idx].vec<float>();

      for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
           ++candidate_idx) {
        // Get current split candidate.
        const auto& node_id = node_ids(candidate_idx);
        const auto& gain = gains(candidate_idx);

        auto best_split_it = best_split_per_node->find(node_id);
        SplitCandidate candidate;
        candidate.feature_idx = feature_idx;
        candidate.candidate_idx = candidate_idx;
        candidate.gain = gain;

        if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
                             GainsAreEqual(gain, best_split_it->second.gain))) {
          const auto best_candidate = (*best_split_per_node)[node_id];
          const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
          const int32 feature_id = feature_ids(candidate.feature_idx);
          VLOG(2) << "Breaking ties on feature ids and buckets";
          // Breaking ties deterministically.
          if (feature_id < best_feature_id) {
            (*best_split_per_node)[node_id] = candidate;
          }
        } else if (best_split_it == best_split_per_node->end() ||
                   GainIsLarger(gain, best_split_it->second.gain)) {
          (*best_split_per_node)[node_id] = candidate;
        }
      }
    }
  }

 private:
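  // Number of features for which split candidates are fed to the op.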
  int32 num_features_;
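  // Pruning strategy, set from the "pruning_mode" attribute.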
  PruningMode pruning_mode_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
                        BoostedTreesUpdateEnsembleOp);

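// Op that centers the initial bias of the ensemble: it adds the leaf value
// computed from the mean gradients and hessians (with l1/l2 regularization) to
// the root node of the first tree and reports whether centering should
// continue.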
class BoostedTreesCenterBiasOp : public OpKernel {
 public:
  explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    BoostedTreesEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
    // Increase the ensemble stamp.
    ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

    // Read the means of the gradients and hessians.
    const Tensor* mean_gradients_t;
    OP_REQUIRES_OK(context,
                   context->input("mean_gradients", &mean_gradients_t));

    const Tensor* mean_hessians_t;
    OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));

    // Get the regularization options.
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();

    // For now, assume 1-dimensional weight on leaves.
    float logits;
    float unused_gain;

    // TODO(nponomareva): change this when supporting multiclass.
    const float gradients_mean = mean_gradients_t->flat<float>()(0);
    const float hessians_mean = mean_hessians_t->flat<float>()(0);
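    // CalculateWeightsAndGains (see tree_helper.h) converts the mean gradient
    // and hessian into a single leaf value (logit) under the given l1/l2
    // regularization; the gain it also computes is unused here.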
    CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits,
                             &unused_gain);

    float current_bias = 0.0;
    bool continue_centering = true;
    if (ensemble_resource->num_trees() == 0) {
      ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
      current_bias = logits;
    } else {
      const auto& current_biases = ensemble_resource->node_value(0, 0);
      DCHECK_EQ(current_biases.size(), 1);
      current_bias = current_biases[0];
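      // Keep centering while the latest update is still large relative to the
      // accumulated bias.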
      continue_centering =
          std::abs(logits / current_bias) > kMinDeltaForCenterBias;
      current_bias += logits;
      ensemble_resource->set_node_value(0, 0, current_bias);
    }

    Tensor* continue_centering_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output("continue_centering", TensorShape({}),
                                          &continue_centering_t));
    // Output whether bias centering should continue.
    continue_centering_t->scalar<bool>()() = continue_centering;
  }
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
                        BoostedTreesCenterBiasOp);

}  // namespace tensorflow