1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
27 #include "tensorflow/core/kernels/boosted_trees/resources.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/refcount.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/work_sharder.h"
36 
37 namespace tensorflow {
38 
ConvertVectorsToMatrices(const OpInputList bucketized_features_list,std::vector<tensorflow::TTypes<int32>::ConstMatrix> & bucketized_features)39 static void ConvertVectorsToMatrices(
40     const OpInputList bucketized_features_list,
41     std::vector<tensorflow::TTypes<int32>::ConstMatrix>& bucketized_features) {
42   for (const Tensor& tensor : bucketized_features_list) {
43     if (tensor.dims() == 1) {
44       const auto v = tensor.vec<int32>();
45       bucketized_features.emplace_back(
46           TTypes<int32>::ConstMatrix(v.data(), v.size(), 1));
47     } else {
48       bucketized_features.emplace_back(tensor.matrix<int32>());
49     }
50   }
51 }
52 
53 // The Op used during training time to get the predictions so far with the
54 // current ensemble being built.
55 // Expect some logits are cached from the previous step and passed through
56 // to be reused.
57 class BoostedTreesTrainingPredictOp : public OpKernel {
58  public:
BoostedTreesTrainingPredictOp(OpKernelConstruction * const context)59   explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
60       : OpKernel(context) {
61     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
62                                              &num_bucketized_features_));
63     OP_REQUIRES_OK(context,
64                    context->GetAttr("logits_dimension", &logits_dimension_));
65   }
66 
Compute(OpKernelContext * const context)67   void Compute(OpKernelContext* const context) override {
68     core::RefCountPtr<BoostedTreesEnsembleResource> resource;
69     // Get the resource.
70     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
71                                            &resource));
72 
73     // Get the inputs.
74     OpInputList bucketized_features_list;
75     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
76                                                 &bucketized_features_list));
77     std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
78     bucketized_features.reserve(bucketized_features_list.size());
79     ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
80     const int batch_size = bucketized_features[0].dimension(0);
81 
82     const Tensor* cached_tree_ids_t;
83     OP_REQUIRES_OK(context,
84                    context->input("cached_tree_ids", &cached_tree_ids_t));
85     const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
86 
87     const Tensor* cached_node_ids_t;
88     OP_REQUIRES_OK(context,
89                    context->input("cached_node_ids", &cached_node_ids_t));
90     const auto cached_node_ids = cached_node_ids_t->vec<int32>();
91 
92     // Allocate outputs.
93     Tensor* output_partial_logits_t = nullptr;
94     OP_REQUIRES_OK(context,
95                    context->allocate_output("partial_logits",
96                                             {batch_size, logits_dimension_},
97                                             &output_partial_logits_t));
98     auto output_partial_logits = output_partial_logits_t->matrix<float>();
99 
100     Tensor* output_tree_ids_t = nullptr;
101     OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
102                                                      &output_tree_ids_t));
103     auto output_tree_ids = output_tree_ids_t->vec<int32>();
104 
105     Tensor* output_node_ids_t = nullptr;
106     OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
107                                                      &output_node_ids_t));
108     auto output_node_ids = output_node_ids_t->vec<int32>();
109 
110     // Indicate that the latest tree was used.
111     const int32 latest_tree = resource->num_trees() - 1;
112 
113     if (latest_tree < 0) {
114       // Ensemble was empty. Output the very first node.
115       output_node_ids.setZero();
116       output_tree_ids = cached_tree_ids;
117       // All the predictions are zeros.
118       output_partial_logits.setZero();
119     } else {
120       output_tree_ids.setConstant(latest_tree);
121       auto do_work = [&resource, &bucketized_features, &cached_tree_ids,
122                       &cached_node_ids, &output_partial_logits,
123                       &output_node_ids, latest_tree,
124                       this](int64 start, int64 end) {
125         for (int32 i = start; i < end; ++i) {
126           int32 tree_id = cached_tree_ids(i);
127           int32 node_id = cached_node_ids(i);
128           std::vector<float> partial_tree_logits(logits_dimension_, 0.0);
129 
130           if (node_id >= 0) {
131             // If the tree was pruned, returns the node id into which the
132             // current_node_id was pruned, as well the correction of the cached
133             // logit prediction.
134             resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
135                                              &partial_tree_logits);
136             // Logic in the loop adds the cached node value again if it is a
137             // leaf. If it is not a leaf anymore we need to subtract the old
138             // node's value. The following logic handles both of these cases.
139             const auto& node_logits = resource->node_value(tree_id, node_id);
140             if (!node_logits.empty()) {
141               DCHECK_EQ(node_logits.size(), logits_dimension_);
142               for (int32 j = 0; j < logits_dimension_; ++j) {
143                 partial_tree_logits[j] -= node_logits[j];
144               }
145             }
146           } else {
147             // No cache exists, start from the very first node.
148             node_id = 0;
149           }
150           std::vector<float> partial_all_logits(logits_dimension_, 0.0);
151           while (true) {
152             if (resource->is_leaf(tree_id, node_id)) {
153               const auto& leaf_logits = resource->node_value(tree_id, node_id);
154               DCHECK_EQ(leaf_logits.size(), logits_dimension_);
155               // Tree is done
156               const float tree_weight = resource->GetTreeWeight(tree_id);
157               for (int32 j = 0; j < logits_dimension_; ++j) {
158                 partial_all_logits[j] +=
159                     tree_weight * (partial_tree_logits[j] + leaf_logits[j]);
160                 partial_tree_logits[j] = 0;
161               }
162               // Stop if it was the latest tree.
163               if (tree_id == latest_tree) {
164                 break;
165               }
166               // Move onto other trees.
167               ++tree_id;
168               node_id = 0;
169             } else {
170               node_id =
171                   resource->next_node(tree_id, node_id, i, bucketized_features);
172             }
173           }
174           output_node_ids(i) = node_id;
175           for (int32 j = 0; j < logits_dimension_; ++j) {
176             output_partial_logits(i, j) = partial_all_logits[j];
177           }
178         }
179       };
180       // 30 is the magic number. The actual value might be a function of (the
181       // number of layers) * (cpu cycles spent on each layer), but this value
182       // would work for many cases. May be tuned later.
183       const int64 cost = 30;
184       thread::ThreadPool* const worker_threads =
185           context->device()->tensorflow_cpu_worker_threads()->workers;
186       Shard(worker_threads->NumThreads(), worker_threads, batch_size,
187             /*cost_per_unit=*/cost, do_work);
188     }
189   }
190 
191  private:
192   int32 logits_dimension_;         // the size of the output prediction vector.
193   int32 num_bucketized_features_;  // Indicates the number of features.
194 };
195 
// Registers the training-time prediction kernel on CPU.
REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
                        BoostedTreesTrainingPredictOp);
198 
199 // The Op to get the predictions at the evaluation/inference time.
200 class BoostedTreesPredictOp : public OpKernel {
201  public:
BoostedTreesPredictOp(OpKernelConstruction * const context)202   explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
203       : OpKernel(context) {
204     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
205                                              &num_bucketized_features_));
206     OP_REQUIRES_OK(context,
207                    context->GetAttr("logits_dimension", &logits_dimension_));
208   }
209 
Compute(OpKernelContext * const context)210   void Compute(OpKernelContext* const context) override {
211     core::RefCountPtr<BoostedTreesEnsembleResource> resource;
212     // Get the resource.
213     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
214                                            &resource));
215 
216     // Get the inputs.
217     OpInputList bucketized_features_list;
218     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
219                                                 &bucketized_features_list));
220     std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
221     bucketized_features.reserve(bucketized_features_list.size());
222     ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
223     const int batch_size = bucketized_features[0].dimension(0);
224 
225     // Allocate outputs.
226     Tensor* output_logits_t = nullptr;
227     OP_REQUIRES_OK(context, context->allocate_output(
228                                 "logits", {batch_size, logits_dimension_},
229                                 &output_logits_t));
230     auto output_logits = output_logits_t->matrix<float>();
231 
232     // Return zero logits if it's an empty ensemble.
233     if (resource->num_trees() <= 0) {
234       output_logits.setZero();
235       return;
236     }
237 
238     const int32 last_tree = resource->num_trees() - 1;
239     auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
240                     this](int64 start, int64 end) {
241       for (int32 i = start; i < end; ++i) {
242         std::vector<float> tree_logits(logits_dimension_, 0.0);
243         int32 tree_id = 0;
244         int32 node_id = 0;
245         while (true) {
246           if (resource->is_leaf(tree_id, node_id)) {
247             const float tree_weight = resource->GetTreeWeight(tree_id);
248             const auto& leaf_logits = resource->node_value(tree_id, node_id);
249             DCHECK_EQ(leaf_logits.size(), logits_dimension_);
250             for (int32 j = 0; j < logits_dimension_; ++j) {
251               tree_logits[j] += tree_weight * leaf_logits[j];
252             }
253             // Stop if it was the last tree.
254             if (tree_id == last_tree) {
255               break;
256             }
257             // Move onto other trees.
258             ++tree_id;
259             node_id = 0;
260           } else {
261             node_id =
262                 resource->next_node(tree_id, node_id, i, bucketized_features);
263           }
264         }
265         for (int32 j = 0; j < logits_dimension_; ++j) {
266           output_logits(i, j) = tree_logits[j];
267         }
268       }
269     };
270     // 10 is the magic number. The actual number might depend on (the number of
271     // layers in the trees) and (cpu cycles spent on each layer), but this
272     // value would work for many cases. May be tuned later.
273     const int64 cost = (last_tree + 1) * 10;
274     thread::ThreadPool* const worker_threads =
275         context->device()->tensorflow_cpu_worker_threads()->workers;
276     Shard(worker_threads->NumThreads(), worker_threads, batch_size,
277           /*cost_per_unit=*/cost, do_work);
278   }
279 
280  private:
281   int32
282       logits_dimension_;  // Indicates the size of the output prediction vector.
283   int32 num_bucketized_features_;  // Indicates the number of features.
284 };
285 
// Registers the inference-time prediction kernel on CPU.
REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
                        BoostedTreesPredictOp);
288 
289 // The Op that returns debugging/model interpretability outputs for each
290 // example. Currently it outputs the split feature ids and logits after each
291 // split along the decision path for each example. This will be used to compute
292 // directional feature contributions at predict time for an arbitrary activation
293 // function.
294 // TODO(crawles): return in proto 1) Node IDs for ensemble prediction path
295 // 2) Leaf node IDs.
296 class BoostedTreesExampleDebugOutputsOp : public OpKernel {
297  public:
BoostedTreesExampleDebugOutputsOp(OpKernelConstruction * const context)298   explicit BoostedTreesExampleDebugOutputsOp(
299       OpKernelConstruction* const context)
300       : OpKernel(context) {
301     OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
302                                              &num_bucketized_features_));
303     OP_REQUIRES_OK(context,
304                    context->GetAttr("logits_dimension", &logits_dimension_));
305     OP_REQUIRES(context, logits_dimension_ == 1,
306                 errors::InvalidArgument(
307                     "Currently only one dimensional outputs are supported."));
308   }
309 
Compute(OpKernelContext * const context)310   void Compute(OpKernelContext* const context) override {
311     core::RefCountPtr<BoostedTreesEnsembleResource> resource;
312     // Get the resource.
313     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
314                                            &resource));
315 
316     // Get the inputs.
317     OpInputList bucketized_features_list;
318     OP_REQUIRES_OK(context, context->input_list("bucketized_features",
319                                                 &bucketized_features_list));
320     std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
321     bucketized_features.reserve(bucketized_features_list.size());
322     ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
323     const int batch_size = bucketized_features[0].dimension(0);
324 
325     // We need to get the feature ids used for splitting and the logits after
326     // each split. We will use these to calculate the changes in the prediction
327     // (contributions) for an arbitrary activation function (done in Python) and
328     // attribute them to the associated feature ids. We will store these in
329     // a proto below.
330     Tensor* output_debug_info_t = nullptr;
331     OP_REQUIRES_OK(
332         context, context->allocate_output("examples_debug_outputs_serialized",
333                                           {batch_size}, &output_debug_info_t));
334     // Will contain serialized protos, per example.
335     auto output_debug_info = output_debug_info_t->flat<tstring>();
336     const int32 last_tree = resource->num_trees() - 1;
337 
338     // For each given example, traverse through all trees keeping track of the
339     // features used to split and the associated logits at each point along the
340     // path. Note: feature_ids has one less value than logits_path because the
341     // first value of each logit path will be the bias.
342     auto do_work = [&resource, &bucketized_features, &output_debug_info,
343                     last_tree](int64 start, int64 end) {
344       for (int32 i = start; i < end; ++i) {
345         // Proto to store debug outputs, per example.
346         boosted_trees::DebugOutput example_debug_info;
347         // Initial bias prediction. E.g., prediction based off training mean.
348         const auto& tree_logits = resource->node_value(0, 0);
349         DCHECK_EQ(tree_logits.size(), 1);
350         float tree_logit = resource->GetTreeWeight(0) * tree_logits[0];
351         example_debug_info.add_logits_path(tree_logit);
352         int32 node_id = 0;
353         int32 tree_id = 0;
354         int32 feature_id;
355         float past_trees_logit = 0;  // Sum of leaf logits from prior trees.
356         // Go through each tree and populate proto.
357         while (tree_id <= last_tree) {
358           if (resource->is_leaf(tree_id, node_id)) {  // Move onto other trees.
359             // Accumulate tree_logits only if the leaf is non-root, but do so
360             // for bias tree.
361             if (tree_id == 0 || node_id > 0) {
362               past_trees_logit += tree_logit;
363             }
364             example_debug_info.add_leaf_node_ids(node_id);
365             ++tree_id;
366             node_id = 0;
367           } else {  // Add to proto.
368             // Feature id used to split.
369             feature_id = resource->feature_id(tree_id, node_id);
370             example_debug_info.add_feature_ids(feature_id);
371             // Get logit after split.
372             node_id =
373                 resource->next_node(tree_id, node_id, i, bucketized_features);
374             const auto& tree_logits = resource->node_value(tree_id, node_id);
375             DCHECK_EQ(tree_logits.size(), 1);
376             tree_logit = resource->GetTreeWeight(tree_id) * tree_logits[0];
377             // Output logit incorporates sum of leaf logits from prior trees.
378             example_debug_info.add_logits_path(tree_logit + past_trees_logit);
379           }
380         }
381         // Set output as serialized proto containing debug info.
382         string serialized = example_debug_info.SerializeAsString();
383         output_debug_info(i) = serialized;
384       }
385     };
386 
387     // 10 is the magic number. The actual number might depend on (the number of
388     // layers in the trees) and (cpu cycles spent on each layer), but this
389     // value would work for many cases. May be tuned later.
390     const int64 cost = (last_tree + 1) * 10;
391     thread::ThreadPool* const worker_threads =
392         context->device()->tensorflow_cpu_worker_threads()->workers;
393     Shard(worker_threads->NumThreads(), worker_threads, batch_size,
394           /*cost_per_unit=*/cost, do_work);
395   }
396 
397  private:
398   int32 logits_dimension_;  // Indicates dimension of logits in the tree nodes.
399   int32 num_bucketized_features_;  // Indicates the number of features.
400 };
401 
// Registers the per-example debug/interpretability kernel on CPU.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU),
    BoostedTreesExampleDebugOutputsOp);
405 
406 }  // namespace tensorflow
407