1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <algorithm> 17 #include <string> 18 #include <vector> 19 20 #include "tensorflow/core/framework/device_base.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" 27 #include "tensorflow/core/kernels/boosted_trees/resources.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/refcount.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/lib/core/threadpool.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/protobuf.h" 34 #include "tensorflow/core/platform/types.h" 35 #include "tensorflow/core/util/work_sharder.h" 36 37 namespace tensorflow { 38 39 // The Op used during training time to get the predictions so far with the 40 // current ensemble being built. 41 // Expect some logits are cached from the previous step and passed through 42 // to be reused. 43 class BoostedTreesTrainingPredictOp : public OpKernel { 44 public: BoostedTreesTrainingPredictOp(OpKernelConstruction * const context)45 explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context) 46 : OpKernel(context) { 47 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features", 48 &num_bucketized_features_)); 49 OP_REQUIRES_OK(context, 50 context->GetAttr("logits_dimension", &logits_dimension_)); 51 OP_REQUIRES(context, logits_dimension_ == 1, 52 errors::InvalidArgument( 53 "Currently only one dimensional outputs are supported.")); 54 } 55 Compute(OpKernelContext * const context)56 void Compute(OpKernelContext* const context) override { 57 BoostedTreesEnsembleResource* resource; 58 // Get the resource. 59 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 60 &resource)); 61 // Release the reference to the resource once we're done using it. 62 core::ScopedUnref unref_me(resource); 63 64 // Get the inputs. 65 OpInputList bucketized_features_list; 66 OP_REQUIRES_OK(context, context->input_list("bucketized_features", 67 &bucketized_features_list)); 68 std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features; 69 batch_bucketized_features.reserve(bucketized_features_list.size()); 70 for (const Tensor& tensor : bucketized_features_list) { 71 batch_bucketized_features.emplace_back(tensor.vec<int32>()); 72 } 73 const int batch_size = batch_bucketized_features[0].size(); 74 75 const Tensor* cached_tree_ids_t; 76 OP_REQUIRES_OK(context, 77 context->input("cached_tree_ids", &cached_tree_ids_t)); 78 const auto cached_tree_ids = cached_tree_ids_t->vec<int32>(); 79 80 const Tensor* cached_node_ids_t; 81 OP_REQUIRES_OK(context, 82 context->input("cached_node_ids", &cached_node_ids_t)); 83 const auto cached_node_ids = cached_node_ids_t->vec<int32>(); 84 85 // Allocate outputs. 86 Tensor* output_partial_logits_t = nullptr; 87 OP_REQUIRES_OK(context, 88 context->allocate_output("partial_logits", 89 {batch_size, logits_dimension_}, 90 &output_partial_logits_t)); 91 auto output_partial_logits = output_partial_logits_t->matrix<float>(); 92 93 Tensor* output_tree_ids_t = nullptr; 94 OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size}, 95 &output_tree_ids_t)); 96 auto output_tree_ids = output_tree_ids_t->vec<int32>(); 97 98 Tensor* output_node_ids_t = nullptr; 99 OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size}, 100 &output_node_ids_t)); 101 auto output_node_ids = output_node_ids_t->vec<int32>(); 102 103 // Indicate that the latest tree was used. 104 const int32 latest_tree = resource->num_trees() - 1; 105 106 if (latest_tree < 0) { 107 // Ensemble was empty. Output the very first node. 108 output_node_ids.setZero(); 109 output_tree_ids = cached_tree_ids; 110 // All the predictions are zeros. 111 output_partial_logits.setZero(); 112 } else { 113 output_tree_ids.setConstant(latest_tree); 114 auto do_work = [&resource, &batch_bucketized_features, &cached_tree_ids, 115 &cached_node_ids, &output_partial_logits, 116 &output_node_ids, latest_tree](int32 start, int32 end) { 117 for (int32 i = start; i < end; ++i) { 118 int32 tree_id = cached_tree_ids(i); 119 int32 node_id = cached_node_ids(i); 120 float partial_tree_logit = 0.0; 121 122 if (node_id >= 0) { 123 // If the tree was pruned, returns the node id into which the 124 // current_node_id was pruned, as well the correction of the cached 125 // logit prediction. 126 resource->GetPostPruneCorrection(tree_id, node_id, &node_id, 127 &partial_tree_logit); 128 // Logic in the loop adds the cached node value again if it is a 129 // leaf. If it is not a leaf anymore we need to subtract the old 130 // node's value. The following logic handles both of these cases. 131 const auto& node_logits = resource->node_value(tree_id, node_id); 132 DCHECK_EQ(node_logits.size(), 1); 133 partial_tree_logit -= node_logits[0]; 134 } else { 135 // No cache exists, start from the very first node. 136 node_id = 0; 137 } 138 float partial_all_logit = 0.0; 139 while (true) { 140 if (resource->is_leaf(tree_id, node_id)) { 141 const auto& leaf_logits = resource->node_value(tree_id, node_id); 142 DCHECK_EQ(leaf_logits.size(), 1); 143 partial_tree_logit += leaf_logits[0]; 144 145 // Tree is done 146 partial_all_logit += 147 resource->GetTreeWeight(tree_id) * partial_tree_logit; 148 partial_tree_logit = 0.0; 149 // Stop if it was the latest tree. 150 if (tree_id == latest_tree) { 151 break; 152 } 153 // Move onto other trees. 154 ++tree_id; 155 node_id = 0; 156 } else { 157 node_id = resource->next_node(tree_id, node_id, i, 158 batch_bucketized_features); 159 } 160 } 161 output_node_ids(i) = node_id; 162 output_partial_logits(i, 0) = partial_all_logit; 163 } 164 }; 165 // 30 is the magic number. The actual value might be a function of (the 166 // number of layers) * (cpu cycles spent on each layer), but this value 167 // would work for many cases. May be tuned later. 168 const int64 cost = 30; 169 thread::ThreadPool* const worker_threads = 170 context->device()->tensorflow_cpu_worker_threads()->workers; 171 Shard(worker_threads->NumThreads(), worker_threads, batch_size, 172 /*cost_per_unit=*/cost, do_work); 173 } 174 } 175 176 private: 177 int32 logits_dimension_; // the size of the output prediction vector. 178 int32 num_bucketized_features_; // Indicates the number of features. 179 }; 180 181 REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU), 182 BoostedTreesTrainingPredictOp); 183 184 // The Op to get the predictions at the evaluation/inference time. 185 class BoostedTreesPredictOp : public OpKernel { 186 public: BoostedTreesPredictOp(OpKernelConstruction * const context)187 explicit BoostedTreesPredictOp(OpKernelConstruction* const context) 188 : OpKernel(context) { 189 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features", 190 &num_bucketized_features_)); 191 OP_REQUIRES_OK(context, 192 context->GetAttr("logits_dimension", &logits_dimension_)); 193 } 194 Compute(OpKernelContext * const context)195 void Compute(OpKernelContext* const context) override { 196 BoostedTreesEnsembleResource* resource; 197 // Get the resource. 198 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 199 &resource)); 200 // Release the reference to the resource once we're done using it. 201 core::ScopedUnref unref_me(resource); 202 203 // Get the inputs. 204 OpInputList bucketized_features_list; 205 OP_REQUIRES_OK(context, context->input_list("bucketized_features", 206 &bucketized_features_list)); 207 std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features; 208 batch_bucketized_features.reserve(bucketized_features_list.size()); 209 for (const Tensor& tensor : bucketized_features_list) { 210 batch_bucketized_features.emplace_back(tensor.vec<int32>()); 211 } 212 const int batch_size = batch_bucketized_features[0].size(); 213 214 // Allocate outputs. 215 Tensor* output_logits_t = nullptr; 216 OP_REQUIRES_OK(context, context->allocate_output( 217 "logits", {batch_size, logits_dimension_}, 218 &output_logits_t)); 219 auto output_logits = output_logits_t->matrix<float>(); 220 221 // Return zero logits if it's an empty ensemble. 222 if (resource->num_trees() <= 0) { 223 output_logits.setZero(); 224 return; 225 } 226 227 const int32 last_tree = resource->num_trees() - 1; 228 auto do_work = [&resource, &batch_bucketized_features, &output_logits, 229 last_tree, this](int32 start, int32 end) { 230 for (int32 i = start; i < end; ++i) { 231 std::vector<float> tree_logits(logits_dimension_, 0.0); 232 int32 tree_id = 0; 233 int32 node_id = 0; 234 while (true) { 235 if (resource->is_leaf(tree_id, node_id)) { 236 const float tree_weight = resource->GetTreeWeight(tree_id); 237 const auto& leaf_logits = resource->node_value(tree_id, node_id); 238 DCHECK_EQ(leaf_logits.size(), logits_dimension_); 239 for (int32 j = 0; j < logits_dimension_; ++j) { 240 tree_logits[j] += tree_weight * leaf_logits[j]; 241 } 242 // Stop if it was the last tree. 243 if (tree_id == last_tree) { 244 break; 245 } 246 // Move onto other trees. 247 ++tree_id; 248 node_id = 0; 249 } else { 250 node_id = resource->next_node(tree_id, node_id, i, 251 batch_bucketized_features); 252 } 253 } 254 for (int32 j = 0; j < logits_dimension_; ++j) { 255 output_logits(i, j) = tree_logits[j]; 256 } 257 } 258 }; 259 // 10 is the magic number. The actual number might depend on (the number of 260 // layers in the trees) and (cpu cycles spent on each layer), but this 261 // value would work for many cases. May be tuned later. 262 const int64 cost = (last_tree + 1) * 10; 263 thread::ThreadPool* const worker_threads = 264 context->device()->tensorflow_cpu_worker_threads()->workers; 265 Shard(worker_threads->NumThreads(), worker_threads, batch_size, 266 /*cost_per_unit=*/cost, do_work); 267 } 268 269 private: 270 int32 271 logits_dimension_; // Indicates the size of the output prediction vector. 272 int32 num_bucketized_features_; // Indicates the number of features. 273 }; 274 275 REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU), 276 BoostedTreesPredictOp); 277 278 // The Op that returns debugging/model interpretability outputs for each 279 // example. Currently it outputs the split feature ids and logits after each 280 // split along the decision path for each example. This will be used to compute 281 // directional feature contributions at predict time for an arbitrary activation 282 // function. 283 // TODO(crawles): return in proto 1) Node IDs for ensemble prediction path 284 // 2) Leaf node IDs. 285 class BoostedTreesExampleDebugOutputsOp : public OpKernel { 286 public: BoostedTreesExampleDebugOutputsOp(OpKernelConstruction * const context)287 explicit BoostedTreesExampleDebugOutputsOp( 288 OpKernelConstruction* const context) 289 : OpKernel(context) { 290 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features", 291 &num_bucketized_features_)); 292 OP_REQUIRES_OK(context, 293 context->GetAttr("logits_dimension", &logits_dimension_)); 294 OP_REQUIRES(context, logits_dimension_ == 1, 295 errors::InvalidArgument( 296 "Currently only one dimensional outputs are supported.")); 297 } 298 Compute(OpKernelContext * const context)299 void Compute(OpKernelContext* const context) override { 300 BoostedTreesEnsembleResource* resource; 301 // Get the resource. 302 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 303 &resource)); 304 // Release the reference to the resource once we're done using it. 305 core::ScopedUnref unref_me(resource); 306 307 // Get the inputs. 308 OpInputList bucketized_features_list; 309 OP_REQUIRES_OK(context, context->input_list("bucketized_features", 310 &bucketized_features_list)); 311 std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features; 312 batch_bucketized_features.reserve(bucketized_features_list.size()); 313 for (const Tensor& tensor : bucketized_features_list) { 314 batch_bucketized_features.emplace_back(tensor.vec<int32>()); 315 } 316 const int batch_size = batch_bucketized_features[0].size(); 317 318 // We need to get the feature ids used for splitting and the logits after 319 // each split. We will use these to calculate the changes in the prediction 320 // (contributions) for an arbitrary activation function (done in Python) and 321 // attribute them to the associated feature ids. We will store these in 322 // a proto below. 323 Tensor* output_debug_info_t = nullptr; 324 OP_REQUIRES_OK( 325 context, context->allocate_output("examples_debug_outputs_serialized", 326 {batch_size}, &output_debug_info_t)); 327 // Will contain serialized protos, per example. 328 auto output_debug_info = output_debug_info_t->flat<string>(); 329 const int32 last_tree = resource->num_trees() - 1; 330 331 // For each given example, traverse through all trees keeping track of the 332 // features used to split and the associated logits at each point along the 333 // path. Note: feature_ids has one less value than logits_path because the 334 // first value of each logit path will be the bias. 335 auto do_work = [&resource, &batch_bucketized_features, &output_debug_info, 336 last_tree](int32 start, int32 end) { 337 for (int32 i = start; i < end; ++i) { 338 // Proto to store debug outputs, per example. 339 boosted_trees::DebugOutput example_debug_info; 340 // Initial bias prediction. E.g., prediction based off training mean. 341 const auto& tree_logits = resource->node_value(0, 0); 342 DCHECK_EQ(tree_logits.size(), 1); 343 float tree_logit = resource->GetTreeWeight(0) * tree_logits[0]; 344 example_debug_info.add_logits_path(tree_logit); 345 int32 node_id = 0; 346 int32 tree_id = 0; 347 int32 feature_id; 348 float past_trees_logit = 0; // Sum of leaf logits from prior trees. 349 // Go through each tree and populate proto. 350 while (tree_id <= last_tree) { 351 if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees. 352 // Accumulate tree_logits only if the leaf is non-root, but do so 353 // for bias tree. 354 if (tree_id == 0 || node_id > 0) { 355 past_trees_logit += tree_logit; 356 } 357 ++tree_id; 358 node_id = 0; 359 } else { // Add to proto. 360 // Feature id used to split. 361 feature_id = resource->feature_id(tree_id, node_id); 362 example_debug_info.add_feature_ids(feature_id); 363 // Get logit after split. 364 node_id = resource->next_node(tree_id, node_id, i, 365 batch_bucketized_features); 366 const auto& tree_logits = resource->node_value(tree_id, node_id); 367 DCHECK_EQ(tree_logits.size(), 1); 368 tree_logit = resource->GetTreeWeight(tree_id) * tree_logits[0]; 369 // Output logit incorporates sum of leaf logits from prior trees. 370 example_debug_info.add_logits_path(tree_logit + past_trees_logit); 371 } 372 } 373 // Set output as serialized proto containing debug info. 374 string serialized = example_debug_info.SerializeAsString(); 375 output_debug_info(i) = serialized; 376 } 377 }; 378 379 // 10 is the magic number. The actual number might depend on (the number of 380 // layers in the trees) and (cpu cycles spent on each layer), but this 381 // value would work for many cases. May be tuned later. 382 const int64 cost = (last_tree + 1) * 10; 383 thread::ThreadPool* const worker_threads = 384 context->device()->tensorflow_cpu_worker_threads()->workers; 385 Shard(worker_threads->NumThreads(), worker_threads, batch_size, 386 /*cost_per_unit=*/cost, do_work); 387 } 388 389 private: 390 int32 logits_dimension_; // Indicates dimension of logits in the tree nodes. 391 int32 num_bucketized_features_; // Indicates the number of features. 392 }; 393 394 REGISTER_KERNEL_BUILDER( 395 Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU), 396 BoostedTreesExampleDebugOutputsOp); 397 398 } // namespace tensorflow 399