/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <limits>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 
22 namespace tensorflow {
23 
24 class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
25  public:
BoostedTreesCalculateBestGainsPerFeatureOp(OpKernelConstruction * const context)26   explicit BoostedTreesCalculateBestGainsPerFeatureOp(
27       OpKernelConstruction* const context)
28       : OpKernel(context) {
29     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
30     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
31   }
32 
Compute(OpKernelContext * const context)33   void Compute(OpKernelContext* const context) override {
34     // node_id_range
35     const Tensor* node_id_range_t;
36     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
37     const auto node_id_range = node_id_range_t->vec<int32>();
38     const int32 node_id_first = node_id_range(0);  // inclusive
39     const int32 node_id_last = node_id_range(1);   // exclusive
40     // stats_summary_list
41     OpInputList stats_summary_list;
42     OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
43                                                 &stats_summary_list));
44     const int64 num_buckets = stats_summary_list[0].dim_size(1);
45     std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
46     stats_summary.reserve(stats_summary_list.size());
47     for (const auto& tensor : stats_summary_list) {
48       stats_summary.emplace_back(tensor.tensor<float, 3>());
49     }
50     const Tensor* l1_t;
51     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
52     const auto l1 = l1_t->scalar<float>()();
53     const Tensor* l2_t;
54     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
55     const auto l2 = l2_t->scalar<float>()();
56     const Tensor* tree_complexity_t;
57     OP_REQUIRES_OK(context,
58                    context->input("tree_complexity", &tree_complexity_t));
59     const auto tree_complexity = tree_complexity_t->scalar<float>()();
60     const Tensor* min_node_weight_t;
61     OP_REQUIRES_OK(context,
62                    context->input("min_node_weight", &min_node_weight_t));
63     const auto min_node_weight = min_node_weight_t->scalar<float>()();
64 
65     // Allocate output lists of tensors:
66     OpOutputList output_node_ids_list;
67     OP_REQUIRES_OK(
68         context, context->output_list("node_ids_list", &output_node_ids_list));
69     OpOutputList output_gains_list;
70     OP_REQUIRES_OK(context,
71                    context->output_list("gains_list", &output_gains_list));
72     OpOutputList output_thresholds_list;
73     OP_REQUIRES_OK(context, context->output_list("thresholds_list",
74                                                  &output_thresholds_list));
75     OpOutputList output_left_node_contribs_list;
76     OP_REQUIRES_OK(context,
77                    context->output_list("left_node_contribs_list",
78                                         &output_left_node_contribs_list));
79     OpOutputList output_right_node_contribs_list;
80     OP_REQUIRES_OK(context,
81                    context->output_list("right_node_contribs_list",
82                                         &output_right_node_contribs_list));
83 
84     // Get the best split info per node for each feature.
85     for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
86       std::vector<float> cum_grad;
87       std::vector<float> cum_hess;
88       cum_grad.reserve(num_buckets);
89       cum_hess.reserve(num_buckets);
90 
91       std::vector<int32> output_node_ids;
92       std::vector<float> output_gains;
93       std::vector<int32> output_thresholds;
94       std::vector<float> output_left_node_contribs;
95       std::vector<float> output_right_node_contribs;
96       for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
97         // Calculate gains.
98         cum_grad.clear();
99         cum_hess.clear();
100         float total_grad = 0.0;
101         float total_hess = 0.0;
102         for (int bucket = 0; bucket < num_buckets; ++bucket) {
103           // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
104           total_grad += stats_summary[feature_idx](node_id, bucket, 0);
105           total_hess += stats_summary[feature_idx](node_id, bucket, 1);
106           cum_grad.push_back(total_grad);
107           cum_hess.push_back(total_hess);
108         }
109         // Check if node has enough of average hessian.
110         if (total_hess < min_node_weight) {
111           // Do not split the node because not enough avg hessian.
112           continue;
113         }
114         float best_gain = std::numeric_limits<float>::lowest();
115         float best_bucket = 0;
116         float best_contrib_for_left = 0.0;
117         float best_contrib_for_right = 0.0;
118         // Parent gain.
119         float parent_gain;
120         float unused;
121         CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
122                                  &parent_gain);
123 
124         for (int bucket = 0; bucket < num_buckets; ++bucket) {
125           const float cum_grad_bucket = cum_grad[bucket];
126           const float cum_hess_bucket = cum_hess[bucket];
127           // Left child.
128           float contrib_for_left;
129           float gain_for_left;
130           CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket, l1, l2,
131                                    &contrib_for_left, &gain_for_left);
132           // Right child.
133           float contrib_for_right;
134           float gain_for_right;
135           CalculateWeightsAndGains(total_grad - cum_grad_bucket,
136                                    total_hess - cum_hess_bucket, l1, l2,
137                                    &contrib_for_right, &gain_for_right);
138 
139           if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
140             best_gain = gain_for_left + gain_for_right;
141             best_bucket = bucket;
142             best_contrib_for_left = contrib_for_left;
143             best_contrib_for_right = contrib_for_right;
144           }
145         }  // for bucket
146         output_node_ids.push_back(node_id);
147         // Remove the parent gain for the parent node.
148         output_gains.push_back(best_gain - parent_gain);
149         output_thresholds.push_back(best_bucket);
150         output_left_node_contribs.push_back(best_contrib_for_left);
151         output_right_node_contribs.push_back(best_contrib_for_right);
152       }  // for node_id
153       const int num_nodes = output_node_ids.size();
154       // output_node_ids
155       Tensor* output_node_ids_t;
156       OP_REQUIRES_OK(context,
157                      output_node_ids_list.allocate(feature_idx, {num_nodes},
158                                                    &output_node_ids_t));
159       auto output_node_ids_vec = output_node_ids_t->vec<int32>();
160       // output_gains
161       Tensor* output_gains_t;
162       OP_REQUIRES_OK(context, output_gains_list.allocate(
163                                   feature_idx, {num_nodes}, &output_gains_t));
164       auto output_gains_vec = output_gains_t->vec<float>();
165       // output_thresholds
166       Tensor* output_thresholds_t;
167       OP_REQUIRES_OK(context,
168                      output_thresholds_list.allocate(feature_idx, {num_nodes},
169                                                      &output_thresholds_t));
170       auto output_thresholds_vec = output_thresholds_t->vec<int32>();
171       // output_left_node_contribs
172       Tensor* output_left_node_contribs_t;
173       OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
174                                   feature_idx, {num_nodes, 1},
175                                   &output_left_node_contribs_t));
176       auto output_left_node_contribs_matrix =
177           output_left_node_contribs_t->matrix<float>();
178       // output_right_node_contribs
179       Tensor* output_right_node_contribs_t;
180       OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
181                                   feature_idx, {num_nodes, 1},
182                                   &output_right_node_contribs_t));
183       auto output_right_node_contribs_matrix =
184           output_right_node_contribs_t->matrix<float>();
185       // Sets output tensors from vectors.
186       for (int i = 0; i < num_nodes; ++i) {
187         output_node_ids_vec(i) = output_node_ids[i];
188         // Adjust the gains to penalize by tree complexity.
189         output_gains_vec(i) = output_gains[i] - tree_complexity;
190         output_thresholds_vec(i) = output_thresholds[i];
191         // Logits are 1-dimensional for now.
192         // TODO(nponomareva): Consider multi-dimensional logits.
193         output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
194         output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
195       }
196     }  // for f
197   }
198 
199  private:
200   int max_splits_;
201   int num_features_;
202 };
203 
// Registers the CPU kernel for BoostedTreesCalculateBestGainsPerFeature.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
    BoostedTreesCalculateBestGainsPerFeatureOp);
207 
208 class BoostedTreesMakeStatsSummaryOp : public OpKernel {
209  public:
BoostedTreesMakeStatsSummaryOp(OpKernelConstruction * const context)210   explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
211       : OpKernel(context) {
212     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
213     OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
214     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
215   }
216 
Compute(OpKernelContext * const context)217   void Compute(OpKernelContext* const context) override {
218     // node_ids
219     const Tensor* node_ids_t;
220     OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
221     const auto node_ids = node_ids_t->vec<int32>();
222     // gradients
223     const Tensor* gradients_t;
224     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
225     const auto gradients = gradients_t->matrix<float>();
226     // hessians
227     const Tensor* hessians_t;
228     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
229     const auto hessians = hessians_t->matrix<float>();
230     // bucketized_features
231     OpInputList bucketized_features_list;
232     OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
233                                                 &bucketized_features_list));
234     // Infer batch size.
235     const int64 batch_size = node_ids_t->dim_size(0);
236 
237     // Allocate temporary stats tensor (Rank 4).
238     Tensor temp_stats_double_t;
239     OP_REQUIRES_OK(context, context->allocate_temp(
240                                 DT_DOUBLE,
241                                 {num_features_, max_splits_, num_buckets_, 2},
242                                 &temp_stats_double_t));
243     auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
244     temp_stats_double.setZero();
245 
246     // Partition by node, and then bucketize.
247     for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
248       const auto& features = bucketized_features_list[feature_idx].vec<int32>();
249       for (int i = 0; i < batch_size; ++i) {
250         const int32 node = node_ids(i);
251         const int32 bucket = features(i);
252         temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
253         temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
254       }
255     }
256 
257     // Copy temp tensor over to output tensor.
258     Tensor* output_stats_summary_t = nullptr;
259     OP_REQUIRES_OK(context, context->allocate_output(
260                                 "stats_summary", temp_stats_double_t.shape(),
261                                 &output_stats_summary_t));
262     output_stats_summary_t->tensor<float, 4>() =
263         temp_stats_double.template cast<float>();
264   }
265 
266  private:
267   int max_splits_;
268   int num_buckets_;
269   int num_features_;
270 };
271 
// Registers the CPU kernel for BoostedTreesMakeStatsSummary.
REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
                        BoostedTreesMakeStatsSummaryOp);
274 
275 }  // namespace tensorflow
276