1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <vector> 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" 21 22 namespace tensorflow { 23 24 class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { 25 public: BoostedTreesCalculateBestGainsPerFeatureOp(OpKernelConstruction * const context)26 explicit BoostedTreesCalculateBestGainsPerFeatureOp( 27 OpKernelConstruction* const context) 28 : OpKernel(context) { 29 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_)); 30 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_)); 31 } 32 Compute(OpKernelContext * const context)33 void Compute(OpKernelContext* const context) override { 34 // node_id_range 35 const Tensor* node_id_range_t; 36 OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t)); 37 const auto node_id_range = node_id_range_t->vec<int32>(); 38 const int32 node_id_first = node_id_range(0); // inclusive 39 const int32 node_id_last = node_id_range(1); // exclusive 40 // stats_summary_list 41 OpInputList stats_summary_list; 42 OP_REQUIRES_OK(context, context->input_list("stats_summary_list", 43 &stats_summary_list)); 44 const int64 num_buckets = stats_summary_list[0].dim_size(1); 45 std::vector<TTypes<float, 3>::ConstTensor> stats_summary; 46 stats_summary.reserve(stats_summary_list.size()); 47 for (const auto& tensor : stats_summary_list) { 48 stats_summary.emplace_back(tensor.tensor<float, 3>()); 49 } 50 const Tensor* l1_t; 51 OP_REQUIRES_OK(context, context->input("l1", &l1_t)); 52 const auto l1 = l1_t->scalar<float>()(); 53 const Tensor* l2_t; 54 OP_REQUIRES_OK(context, context->input("l2", &l2_t)); 55 const auto l2 = l2_t->scalar<float>()(); 56 const Tensor* tree_complexity_t; 57 OP_REQUIRES_OK(context, 58 context->input("tree_complexity", &tree_complexity_t)); 59 const auto tree_complexity = tree_complexity_t->scalar<float>()(); 60 const Tensor* min_node_weight_t; 61 OP_REQUIRES_OK(context, 62 context->input("min_node_weight", &min_node_weight_t)); 63 const auto min_node_weight = min_node_weight_t->scalar<float>()(); 64 65 // Allocate output lists of tensors: 66 OpOutputList output_node_ids_list; 67 OP_REQUIRES_OK( 68 context, context->output_list("node_ids_list", &output_node_ids_list)); 69 OpOutputList output_gains_list; 70 OP_REQUIRES_OK(context, 71 context->output_list("gains_list", &output_gains_list)); 72 OpOutputList output_thresholds_list; 73 OP_REQUIRES_OK(context, context->output_list("thresholds_list", 74 &output_thresholds_list)); 75 OpOutputList output_left_node_contribs_list; 76 OP_REQUIRES_OK(context, 77 context->output_list("left_node_contribs_list", 78 &output_left_node_contribs_list)); 79 OpOutputList output_right_node_contribs_list; 80 OP_REQUIRES_OK(context, 81 context->output_list("right_node_contribs_list", 82 &output_right_node_contribs_list)); 83 84 // Get the best split info per node for each feature. 85 for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) { 86 std::vector<float> cum_grad; 87 std::vector<float> cum_hess; 88 cum_grad.reserve(num_buckets); 89 cum_hess.reserve(num_buckets); 90 91 std::vector<int32> output_node_ids; 92 std::vector<float> output_gains; 93 std::vector<int32> output_thresholds; 94 std::vector<float> output_left_node_contribs; 95 std::vector<float> output_right_node_contribs; 96 for (int node_id = node_id_first; node_id < node_id_last; ++node_id) { 97 // Calculate gains. 98 cum_grad.clear(); 99 cum_hess.clear(); 100 float total_grad = 0.0; 101 float total_hess = 0.0; 102 for (int bucket = 0; bucket < num_buckets; ++bucket) { 103 // TODO(nponomareva): Consider multi-dimensional gradients/hessians. 104 total_grad += stats_summary[feature_idx](node_id, bucket, 0); 105 total_hess += stats_summary[feature_idx](node_id, bucket, 1); 106 cum_grad.push_back(total_grad); 107 cum_hess.push_back(total_hess); 108 } 109 // Check if node has enough of average hessian. 110 if (total_hess < min_node_weight) { 111 // Do not split the node because not enough avg hessian. 112 continue; 113 } 114 float best_gain = std::numeric_limits<float>::lowest(); 115 float best_bucket = 0; 116 float best_contrib_for_left = 0.0; 117 float best_contrib_for_right = 0.0; 118 // Parent gain. 119 float parent_gain; 120 float unused; 121 CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused, 122 &parent_gain); 123 124 for (int bucket = 0; bucket < num_buckets; ++bucket) { 125 const float cum_grad_bucket = cum_grad[bucket]; 126 const float cum_hess_bucket = cum_hess[bucket]; 127 // Left child. 128 float contrib_for_left; 129 float gain_for_left; 130 CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket, l1, l2, 131 &contrib_for_left, &gain_for_left); 132 // Right child. 133 float contrib_for_right; 134 float gain_for_right; 135 CalculateWeightsAndGains(total_grad - cum_grad_bucket, 136 total_hess - cum_hess_bucket, l1, l2, 137 &contrib_for_right, &gain_for_right); 138 139 if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) { 140 best_gain = gain_for_left + gain_for_right; 141 best_bucket = bucket; 142 best_contrib_for_left = contrib_for_left; 143 best_contrib_for_right = contrib_for_right; 144 } 145 } // for bucket 146 output_node_ids.push_back(node_id); 147 // Remove the parent gain for the parent node. 148 output_gains.push_back(best_gain - parent_gain); 149 output_thresholds.push_back(best_bucket); 150 output_left_node_contribs.push_back(best_contrib_for_left); 151 output_right_node_contribs.push_back(best_contrib_for_right); 152 } // for node_id 153 const int num_nodes = output_node_ids.size(); 154 // output_node_ids 155 Tensor* output_node_ids_t; 156 OP_REQUIRES_OK(context, 157 output_node_ids_list.allocate(feature_idx, {num_nodes}, 158 &output_node_ids_t)); 159 auto output_node_ids_vec = output_node_ids_t->vec<int32>(); 160 // output_gains 161 Tensor* output_gains_t; 162 OP_REQUIRES_OK(context, output_gains_list.allocate( 163 feature_idx, {num_nodes}, &output_gains_t)); 164 auto output_gains_vec = output_gains_t->vec<float>(); 165 // output_thresholds 166 Tensor* output_thresholds_t; 167 OP_REQUIRES_OK(context, 168 output_thresholds_list.allocate(feature_idx, {num_nodes}, 169 &output_thresholds_t)); 170 auto output_thresholds_vec = output_thresholds_t->vec<int32>(); 171 // output_left_node_contribs 172 Tensor* output_left_node_contribs_t; 173 OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate( 174 feature_idx, {num_nodes, 1}, 175 &output_left_node_contribs_t)); 176 auto output_left_node_contribs_matrix = 177 output_left_node_contribs_t->matrix<float>(); 178 // output_right_node_contribs 179 Tensor* output_right_node_contribs_t; 180 OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate( 181 feature_idx, {num_nodes, 1}, 182 &output_right_node_contribs_t)); 183 auto output_right_node_contribs_matrix = 184 output_right_node_contribs_t->matrix<float>(); 185 // Sets output tensors from vectors. 186 for (int i = 0; i < num_nodes; ++i) { 187 output_node_ids_vec(i) = output_node_ids[i]; 188 // Adjust the gains to penalize by tree complexity. 189 output_gains_vec(i) = output_gains[i] - tree_complexity; 190 output_thresholds_vec(i) = output_thresholds[i]; 191 // Logits are 1-dimensional for now. 192 // TODO(nponomareva): Consider multi-dimensional logits. 193 output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i]; 194 output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i]; 195 } 196 } // for f 197 } 198 199 private: 200 int max_splits_; 201 int num_features_; 202 }; 203 204 REGISTER_KERNEL_BUILDER( 205 Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU), 206 BoostedTreesCalculateBestGainsPerFeatureOp); 207 208 class BoostedTreesMakeStatsSummaryOp : public OpKernel { 209 public: BoostedTreesMakeStatsSummaryOp(OpKernelConstruction * const context)210 explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context) 211 : OpKernel(context) { 212 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_)); 213 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_)); 214 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_)); 215 } 216 Compute(OpKernelContext * const context)217 void Compute(OpKernelContext* const context) override { 218 // node_ids 219 const Tensor* node_ids_t; 220 OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t)); 221 const auto node_ids = node_ids_t->vec<int32>(); 222 // gradients 223 const Tensor* gradients_t; 224 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 225 const auto gradients = gradients_t->matrix<float>(); 226 // hessians 227 const Tensor* hessians_t; 228 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 229 const auto hessians = hessians_t->matrix<float>(); 230 // bucketized_features 231 OpInputList bucketized_features_list; 232 OP_REQUIRES_OK(context, context->input_list("bucketized_features_list", 233 &bucketized_features_list)); 234 // Infer batch size. 235 const int64 batch_size = node_ids_t->dim_size(0); 236 237 // Allocate temporary stats tensor (Rank 4). 238 Tensor temp_stats_double_t; 239 OP_REQUIRES_OK(context, context->allocate_temp( 240 DT_DOUBLE, 241 {num_features_, max_splits_, num_buckets_, 2}, 242 &temp_stats_double_t)); 243 auto temp_stats_double = temp_stats_double_t.tensor<double, 4>(); 244 temp_stats_double.setZero(); 245 246 // Partition by node, and then bucketize. 247 for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) { 248 const auto& features = bucketized_features_list[feature_idx].vec<int32>(); 249 for (int i = 0; i < batch_size; ++i) { 250 const int32 node = node_ids(i); 251 const int32 bucket = features(i); 252 temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0); 253 temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0); 254 } 255 } 256 257 // Copy temp tensor over to output tensor. 258 Tensor* output_stats_summary_t = nullptr; 259 OP_REQUIRES_OK(context, context->allocate_output( 260 "stats_summary", temp_stats_double_t.shape(), 261 &output_stats_summary_t)); 262 output_stats_summary_t->tensor<float, 4>() = 263 temp_stats_double.template cast<float>(); 264 } 265 266 private: 267 int max_splits_; 268 int num_buckets_; 269 int num_features_; 270 }; 271 272 REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU), 273 BoostedTreesMakeStatsSummaryOp); 274 275 } // namespace tensorflow 276