1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
27 #include "tensorflow/core/kernels/boosted_trees/resources.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/refcount.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/work_sharder.h"
36 
37 namespace tensorflow {
38 
39 // The Op used during training time to get the predictions so far with the
40 // current ensemble being built.
41 // Expect some logits are cached from the previous step and passed through
42 // to be reused.
43 class BoostedTreesTrainingPredictOp : public OpKernel {
44  public:
BoostedTreesTrainingPredictOp(OpKernelConstruction * const context)45   explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
46       : OpKernel(context) {
47     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
48                                              &num_bucketized_features_));
49     OP_REQUIRES_OK(context,
50                    context->GetAttr("logits_dimension", &logits_dimension_));
51     OP_REQUIRES(context, logits_dimension_ == 1,
52                 errors::InvalidArgument(
53                     "Currently only one dimensional outputs are supported."));
54   }
55 
Compute(OpKernelContext * const context)56   void Compute(OpKernelContext* const context) override {
57     BoostedTreesEnsembleResource* resource;
58     // Get the resource.
59     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
60                                            &resource));
61     // Release the reference to the resource once we're done using it.
62     core::ScopedUnref unref_me(resource);
63 
64     // Get the inputs.
65     OpInputList bucketized_features_list;
66     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
67                                                 &bucketized_features_list));
68     std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
69     batch_bucketized_features.reserve(bucketized_features_list.size());
70     for (const Tensor& tensor : bucketized_features_list) {
71       batch_bucketized_features.emplace_back(tensor.vec<int32>());
72     }
73     const int batch_size = batch_bucketized_features[0].size();
74 
75     const Tensor* cached_tree_ids_t;
76     OP_REQUIRES_OK(context,
77                    context->input("cached_tree_ids", &cached_tree_ids_t));
78     const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
79 
80     const Tensor* cached_node_ids_t;
81     OP_REQUIRES_OK(context,
82                    context->input("cached_node_ids", &cached_node_ids_t));
83     const auto cached_node_ids = cached_node_ids_t->vec<int32>();
84 
85     // Allocate outputs.
86     Tensor* output_partial_logits_t = nullptr;
87     OP_REQUIRES_OK(context,
88                    context->allocate_output("partial_logits",
89                                             {batch_size, logits_dimension_},
90                                             &output_partial_logits_t));
91     auto output_partial_logits = output_partial_logits_t->matrix<float>();
92 
93     Tensor* output_tree_ids_t = nullptr;
94     OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
95                                                      &output_tree_ids_t));
96     auto output_tree_ids = output_tree_ids_t->vec<int32>();
97 
98     Tensor* output_node_ids_t = nullptr;
99     OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
100                                                      &output_node_ids_t));
101     auto output_node_ids = output_node_ids_t->vec<int32>();
102 
103     // Indicate that the latest tree was used.
104     const int32 latest_tree = resource->num_trees() - 1;
105 
106     if (latest_tree < 0) {
107       // Ensemble was empty. Output the very first node.
108       output_node_ids.setZero();
109       output_tree_ids = cached_tree_ids;
110       // All the predictions are zeros.
111       output_partial_logits.setZero();
112     } else {
113       output_tree_ids.setConstant(latest_tree);
114       auto do_work = [&resource, &batch_bucketized_features, &cached_tree_ids,
115                       &cached_node_ids, &output_partial_logits,
116                       &output_node_ids, latest_tree](int32 start, int32 end) {
117         for (int32 i = start; i < end; ++i) {
118           int32 tree_id = cached_tree_ids(i);
119           int32 node_id = cached_node_ids(i);
120           float partial_tree_logit = 0.0;
121 
122           if (node_id >= 0) {
123             // If the tree was pruned, returns the node id into which the
124             // current_node_id was pruned, as well the correction of the cached
125             // logit prediction.
126             resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
127                                              &partial_tree_logit);
128             // Logic in the loop adds the cached node value again if it is a
129             // leaf. If it is not a leaf anymore we need to subtract the old
130             // node's value. The following logic handles both of these cases.
131             const auto& node_logits = resource->node_value(tree_id, node_id);
132             DCHECK_EQ(node_logits.size(), 1);
133             partial_tree_logit -= node_logits[0];
134           } else {
135             // No cache exists, start from the very first node.
136             node_id = 0;
137           }
138           float partial_all_logit = 0.0;
139           while (true) {
140             if (resource->is_leaf(tree_id, node_id)) {
141               const auto& leaf_logits = resource->node_value(tree_id, node_id);
142               DCHECK_EQ(leaf_logits.size(), 1);
143               partial_tree_logit += leaf_logits[0];
144 
145               // Tree is done
146               partial_all_logit +=
147                   resource->GetTreeWeight(tree_id) * partial_tree_logit;
148               partial_tree_logit = 0.0;
149               // Stop if it was the latest tree.
150               if (tree_id == latest_tree) {
151                 break;
152               }
153               // Move onto other trees.
154               ++tree_id;
155               node_id = 0;
156             } else {
157               node_id = resource->next_node(tree_id, node_id, i,
158                                             batch_bucketized_features);
159             }
160           }
161           output_node_ids(i) = node_id;
162           output_partial_logits(i, 0) = partial_all_logit;
163         }
164       };
165       // 30 is the magic number. The actual value might be a function of (the
166       // number of layers) * (cpu cycles spent on each layer), but this value
167       // would work for many cases. May be tuned later.
168       const int64 cost = 30;
169       thread::ThreadPool* const worker_threads =
170           context->device()->tensorflow_cpu_worker_threads()->workers;
171       Shard(worker_threads->NumThreads(), worker_threads, batch_size,
172             /*cost_per_unit=*/cost, do_work);
173     }
174   }
175 
176  private:
177   int32 logits_dimension_;         // the size of the output prediction vector.
178   int32 num_bucketized_features_;  // Indicates the number of features.
179 };
180 
181 REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
182                         BoostedTreesTrainingPredictOp);
183 
184 // The Op to get the predictions at the evaluation/inference time.
185 class BoostedTreesPredictOp : public OpKernel {
186  public:
BoostedTreesPredictOp(OpKernelConstruction * const context)187   explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
188       : OpKernel(context) {
189     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
190                                              &num_bucketized_features_));
191     OP_REQUIRES_OK(context,
192                    context->GetAttr("logits_dimension", &logits_dimension_));
193   }
194 
Compute(OpKernelContext * const context)195   void Compute(OpKernelContext* const context) override {
196     BoostedTreesEnsembleResource* resource;
197     // Get the resource.
198     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
199                                            &resource));
200     // Release the reference to the resource once we're done using it.
201     core::ScopedUnref unref_me(resource);
202 
203     // Get the inputs.
204     OpInputList bucketized_features_list;
205     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
206                                                 &bucketized_features_list));
207     std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
208     batch_bucketized_features.reserve(bucketized_features_list.size());
209     for (const Tensor& tensor : bucketized_features_list) {
210       batch_bucketized_features.emplace_back(tensor.vec<int32>());
211     }
212     const int batch_size = batch_bucketized_features[0].size();
213 
214     // Allocate outputs.
215     Tensor* output_logits_t = nullptr;
216     OP_REQUIRES_OK(context, context->allocate_output(
217                                 "logits", {batch_size, logits_dimension_},
218                                 &output_logits_t));
219     auto output_logits = output_logits_t->matrix<float>();
220 
221     // Return zero logits if it's an empty ensemble.
222     if (resource->num_trees() <= 0) {
223       output_logits.setZero();
224       return;
225     }
226 
227     const int32 last_tree = resource->num_trees() - 1;
228     auto do_work = [&resource, &batch_bucketized_features, &output_logits,
229                     last_tree, this](int32 start, int32 end) {
230       for (int32 i = start; i < end; ++i) {
231         std::vector<float> tree_logits(logits_dimension_, 0.0);
232         int32 tree_id = 0;
233         int32 node_id = 0;
234         while (true) {
235           if (resource->is_leaf(tree_id, node_id)) {
236             const float tree_weight = resource->GetTreeWeight(tree_id);
237             const auto& leaf_logits = resource->node_value(tree_id, node_id);
238             DCHECK_EQ(leaf_logits.size(), logits_dimension_);
239             for (int32 j = 0; j < logits_dimension_; ++j) {
240               tree_logits[j] += tree_weight * leaf_logits[j];
241             }
242             // Stop if it was the last tree.
243             if (tree_id == last_tree) {
244               break;
245             }
246             // Move onto other trees.
247             ++tree_id;
248             node_id = 0;
249           } else {
250             node_id = resource->next_node(tree_id, node_id, i,
251                                           batch_bucketized_features);
252           }
253         }
254         for (int32 j = 0; j < logits_dimension_; ++j) {
255           output_logits(i, j) = tree_logits[j];
256         }
257       }
258     };
259     // 10 is the magic number. The actual number might depend on (the number of
260     // layers in the trees) and (cpu cycles spent on each layer), but this
261     // value would work for many cases. May be tuned later.
262     const int64 cost = (last_tree + 1) * 10;
263     thread::ThreadPool* const worker_threads =
264         context->device()->tensorflow_cpu_worker_threads()->workers;
265     Shard(worker_threads->NumThreads(), worker_threads, batch_size,
266           /*cost_per_unit=*/cost, do_work);
267   }
268 
269  private:
270   int32
271       logits_dimension_;  // Indicates the size of the output prediction vector.
272   int32 num_bucketized_features_;  // Indicates the number of features.
273 };
274 
275 REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
276                         BoostedTreesPredictOp);
277 
278 // The Op that returns debugging/model interpretability outputs for each
279 // example. Currently it outputs the split feature ids and logits after each
280 // split along the decision path for each example. This will be used to compute
281 // directional feature contributions at predict time for an arbitrary activation
282 // function.
283 // TODO(crawles): return in proto 1) Node IDs for ensemble prediction path
284 // 2) Leaf node IDs.
285 class BoostedTreesExampleDebugOutputsOp : public OpKernel {
286  public:
BoostedTreesExampleDebugOutputsOp(OpKernelConstruction * const context)287   explicit BoostedTreesExampleDebugOutputsOp(
288       OpKernelConstruction* const context)
289       : OpKernel(context) {
290     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
291                                              &num_bucketized_features_));
292     OP_REQUIRES_OK(context,
293                    context->GetAttr("logits_dimension", &logits_dimension_));
294     OP_REQUIRES(context, logits_dimension_ == 1,
295                 errors::InvalidArgument(
296                     "Currently only one dimensional outputs are supported."));
297   }
298 
Compute(OpKernelContext * const context)299   void Compute(OpKernelContext* const context) override {
300     BoostedTreesEnsembleResource* resource;
301     // Get the resource.
302     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
303                                            &resource));
304     // Release the reference to the resource once we're done using it.
305     core::ScopedUnref unref_me(resource);
306 
307     // Get the inputs.
308     OpInputList bucketized_features_list;
309     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
310                                                 &bucketized_features_list));
311     std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
312     batch_bucketized_features.reserve(bucketized_features_list.size());
313     for (const Tensor& tensor : bucketized_features_list) {
314       batch_bucketized_features.emplace_back(tensor.vec<int32>());
315     }
316     const int batch_size = batch_bucketized_features[0].size();
317 
318     // We need to get the feature ids used for splitting and the logits after
319     // each split. We will use these to calculate the changes in the prediction
320     // (contributions) for an arbitrary activation function (done in Python) and
321     // attribute them to the associated feature ids. We will store these in
322     // a proto below.
323     Tensor* output_debug_info_t = nullptr;
324     OP_REQUIRES_OK(
325         context, context->allocate_output("examples_debug_outputs_serialized",
326                                           {batch_size}, &output_debug_info_t));
327     // Will contain serialized protos, per example.
328     auto output_debug_info = output_debug_info_t->flat<string>();
329     const int32 last_tree = resource->num_trees() - 1;
330 
331     // For each given example, traverse through all trees keeping track of the
332     // features used to split and the associated logits at each point along the
333     // path. Note: feature_ids has one less value than logits_path because the
334     // first value of each logit path will be the bias.
335     auto do_work = [&resource, &batch_bucketized_features, &output_debug_info,
336                     last_tree](int32 start, int32 end) {
337       for (int32 i = start; i < end; ++i) {
338         // Proto to store debug outputs, per example.
339         boosted_trees::DebugOutput example_debug_info;
340         // Initial bias prediction. E.g., prediction based off training mean.
341         const auto& tree_logits = resource->node_value(0, 0);
342         DCHECK_EQ(tree_logits.size(), 1);
343         float tree_logit = resource->GetTreeWeight(0) * tree_logits[0];
344         example_debug_info.add_logits_path(tree_logit);
345         int32 node_id = 0;
346         int32 tree_id = 0;
347         int32 feature_id;
348         float past_trees_logit = 0;  // Sum of leaf logits from prior trees.
349         // Go through each tree and populate proto.
350         while (tree_id <= last_tree) {
351           if (resource->is_leaf(tree_id, node_id)) {  // Move onto other trees.
352             // Accumulate tree_logits only if the leaf is non-root, but do so
353             // for bias tree.
354             if (tree_id == 0 || node_id > 0) {
355               past_trees_logit += tree_logit;
356             }
357             ++tree_id;
358             node_id = 0;
359           } else {  // Add to proto.
360             // Feature id used to split.
361             feature_id = resource->feature_id(tree_id, node_id);
362             example_debug_info.add_feature_ids(feature_id);
363             // Get logit after split.
364             node_id = resource->next_node(tree_id, node_id, i,
365                                           batch_bucketized_features);
366             const auto& tree_logits = resource->node_value(tree_id, node_id);
367             DCHECK_EQ(tree_logits.size(), 1);
368             tree_logit = resource->GetTreeWeight(tree_id) * tree_logits[0];
369             // Output logit incorporates sum of leaf logits from prior trees.
370             example_debug_info.add_logits_path(tree_logit + past_trees_logit);
371           }
372         }
373         // Set output as serialized proto containing debug info.
374         string serialized = example_debug_info.SerializeAsString();
375         output_debug_info(i) = serialized;
376       }
377     };
378 
379     // 10 is the magic number. The actual number might depend on (the number of
380     // layers in the trees) and (cpu cycles spent on each layer), but this
381     // value would work for many cases. May be tuned later.
382     const int64 cost = (last_tree + 1) * 10;
383     thread::ThreadPool* const worker_threads =
384         context->device()->tensorflow_cpu_worker_threads()->workers;
385     Shard(worker_threads->NumThreads(), worker_threads, batch_size,
386           /*cost_per_unit=*/cost, do_work);
387   }
388 
389  private:
390   int32 logits_dimension_;  // Indicates dimension of logits in the tree nodes.
391   int32 num_bucketized_features_;  // Indicates the number of features.
392 };
393 
394 REGISTER_KERNEL_BUILDER(
395     Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU),
396     BoostedTreesExampleDebugOutputsOp);
397 
398 }  // namespace tensorflow
399