1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <string>
18 #include <vector>
19
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
27 #include "tensorflow/core/kernels/boosted_trees/resources.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/refcount.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/work_sharder.h"
36
37 namespace tensorflow {
38
ConvertVectorsToMatrices(const OpInputList bucketized_features_list,std::vector<tensorflow::TTypes<int32>::ConstMatrix> & bucketized_features)39 static void ConvertVectorsToMatrices(
40 const OpInputList bucketized_features_list,
41 std::vector<tensorflow::TTypes<int32>::ConstMatrix>& bucketized_features) {
42 for (const Tensor& tensor : bucketized_features_list) {
43 if (tensor.dims() == 1) {
44 const auto v = tensor.vec<int32>();
45 bucketized_features.emplace_back(
46 TTypes<int32>::ConstMatrix(v.data(), v.size(), 1));
47 } else {
48 bucketized_features.emplace_back(tensor.matrix<int32>());
49 }
50 }
51 }
52
53 // The Op used during training time to get the predictions so far with the
54 // current ensemble being built.
55 // Expect some logits are cached from the previous step and passed through
56 // to be reused.
57 class BoostedTreesTrainingPredictOp : public OpKernel {
58 public:
BoostedTreesTrainingPredictOp(OpKernelConstruction * const context)59 explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
60 : OpKernel(context) {
61 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
62 &num_bucketized_features_));
63 OP_REQUIRES_OK(context,
64 context->GetAttr("logits_dimension", &logits_dimension_));
65 }
66
Compute(OpKernelContext * const context)67 void Compute(OpKernelContext* const context) override {
68 core::RefCountPtr<BoostedTreesEnsembleResource> resource;
69 // Get the resource.
70 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
71 &resource));
72
73 // Get the inputs.
74 OpInputList bucketized_features_list;
75 OP_REQUIRES_OK(context, context->input_list("bucketized_features",
76 &bucketized_features_list));
77 std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
78 bucketized_features.reserve(bucketized_features_list.size());
79 ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
80 const int batch_size = bucketized_features[0].dimension(0);
81
82 const Tensor* cached_tree_ids_t;
83 OP_REQUIRES_OK(context,
84 context->input("cached_tree_ids", &cached_tree_ids_t));
85 const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
86
87 const Tensor* cached_node_ids_t;
88 OP_REQUIRES_OK(context,
89 context->input("cached_node_ids", &cached_node_ids_t));
90 const auto cached_node_ids = cached_node_ids_t->vec<int32>();
91
92 // Allocate outputs.
93 Tensor* output_partial_logits_t = nullptr;
94 OP_REQUIRES_OK(context,
95 context->allocate_output("partial_logits",
96 {batch_size, logits_dimension_},
97 &output_partial_logits_t));
98 auto output_partial_logits = output_partial_logits_t->matrix<float>();
99
100 Tensor* output_tree_ids_t = nullptr;
101 OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
102 &output_tree_ids_t));
103 auto output_tree_ids = output_tree_ids_t->vec<int32>();
104
105 Tensor* output_node_ids_t = nullptr;
106 OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
107 &output_node_ids_t));
108 auto output_node_ids = output_node_ids_t->vec<int32>();
109
110 // Indicate that the latest tree was used.
111 const int32 latest_tree = resource->num_trees() - 1;
112
113 if (latest_tree < 0) {
114 // Ensemble was empty. Output the very first node.
115 output_node_ids.setZero();
116 output_tree_ids = cached_tree_ids;
117 // All the predictions are zeros.
118 output_partial_logits.setZero();
119 } else {
120 output_tree_ids.setConstant(latest_tree);
121 auto do_work = [&resource, &bucketized_features, &cached_tree_ids,
122 &cached_node_ids, &output_partial_logits,
123 &output_node_ids, latest_tree,
124 this](int64 start, int64 end) {
125 for (int32 i = start; i < end; ++i) {
126 int32 tree_id = cached_tree_ids(i);
127 int32 node_id = cached_node_ids(i);
128 std::vector<float> partial_tree_logits(logits_dimension_, 0.0);
129
130 if (node_id >= 0) {
131 // If the tree was pruned, returns the node id into which the
132 // current_node_id was pruned, as well the correction of the cached
133 // logit prediction.
134 resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
135 &partial_tree_logits);
136 // Logic in the loop adds the cached node value again if it is a
137 // leaf. If it is not a leaf anymore we need to subtract the old
138 // node's value. The following logic handles both of these cases.
139 const auto& node_logits = resource->node_value(tree_id, node_id);
140 if (!node_logits.empty()) {
141 DCHECK_EQ(node_logits.size(), logits_dimension_);
142 for (int32 j = 0; j < logits_dimension_; ++j) {
143 partial_tree_logits[j] -= node_logits[j];
144 }
145 }
146 } else {
147 // No cache exists, start from the very first node.
148 node_id = 0;
149 }
150 std::vector<float> partial_all_logits(logits_dimension_, 0.0);
151 while (true) {
152 if (resource->is_leaf(tree_id, node_id)) {
153 const auto& leaf_logits = resource->node_value(tree_id, node_id);
154 DCHECK_EQ(leaf_logits.size(), logits_dimension_);
155 // Tree is done
156 const float tree_weight = resource->GetTreeWeight(tree_id);
157 for (int32 j = 0; j < logits_dimension_; ++j) {
158 partial_all_logits[j] +=
159 tree_weight * (partial_tree_logits[j] + leaf_logits[j]);
160 partial_tree_logits[j] = 0;
161 }
162 // Stop if it was the latest tree.
163 if (tree_id == latest_tree) {
164 break;
165 }
166 // Move onto other trees.
167 ++tree_id;
168 node_id = 0;
169 } else {
170 node_id =
171 resource->next_node(tree_id, node_id, i, bucketized_features);
172 }
173 }
174 output_node_ids(i) = node_id;
175 for (int32 j = 0; j < logits_dimension_; ++j) {
176 output_partial_logits(i, j) = partial_all_logits[j];
177 }
178 }
179 };
180 // 30 is the magic number. The actual value might be a function of (the
181 // number of layers) * (cpu cycles spent on each layer), but this value
182 // would work for many cases. May be tuned later.
183 const int64 cost = 30;
184 thread::ThreadPool* const worker_threads =
185 context->device()->tensorflow_cpu_worker_threads()->workers;
186 Shard(worker_threads->NumThreads(), worker_threads, batch_size,
187 /*cost_per_unit=*/cost, do_work);
188 }
189 }
190
191 private:
192 int32 logits_dimension_; // the size of the output prediction vector.
193 int32 num_bucketized_features_; // Indicates the number of features.
194 };
195
196 REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
197 BoostedTreesTrainingPredictOp);
198
199 // The Op to get the predictions at the evaluation/inference time.
200 class BoostedTreesPredictOp : public OpKernel {
201 public:
BoostedTreesPredictOp(OpKernelConstruction * const context)202 explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
203 : OpKernel(context) {
204 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
205 &num_bucketized_features_));
206 OP_REQUIRES_OK(context,
207 context->GetAttr("logits_dimension", &logits_dimension_));
208 }
209
Compute(OpKernelContext * const context)210 void Compute(OpKernelContext* const context) override {
211 core::RefCountPtr<BoostedTreesEnsembleResource> resource;
212 // Get the resource.
213 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
214 &resource));
215
216 // Get the inputs.
217 OpInputList bucketized_features_list;
218 OP_REQUIRES_OK(context, context->input_list("bucketized_features",
219 &bucketized_features_list));
220 std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
221 bucketized_features.reserve(bucketized_features_list.size());
222 ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
223 const int batch_size = bucketized_features[0].dimension(0);
224
225 // Allocate outputs.
226 Tensor* output_logits_t = nullptr;
227 OP_REQUIRES_OK(context, context->allocate_output(
228 "logits", {batch_size, logits_dimension_},
229 &output_logits_t));
230 auto output_logits = output_logits_t->matrix<float>();
231
232 // Return zero logits if it's an empty ensemble.
233 if (resource->num_trees() <= 0) {
234 output_logits.setZero();
235 return;
236 }
237
238 const int32 last_tree = resource->num_trees() - 1;
239 auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
240 this](int64 start, int64 end) {
241 for (int32 i = start; i < end; ++i) {
242 std::vector<float> tree_logits(logits_dimension_, 0.0);
243 int32 tree_id = 0;
244 int32 node_id = 0;
245 while (true) {
246 if (resource->is_leaf(tree_id, node_id)) {
247 const float tree_weight = resource->GetTreeWeight(tree_id);
248 const auto& leaf_logits = resource->node_value(tree_id, node_id);
249 DCHECK_EQ(leaf_logits.size(), logits_dimension_);
250 for (int32 j = 0; j < logits_dimension_; ++j) {
251 tree_logits[j] += tree_weight * leaf_logits[j];
252 }
253 // Stop if it was the last tree.
254 if (tree_id == last_tree) {
255 break;
256 }
257 // Move onto other trees.
258 ++tree_id;
259 node_id = 0;
260 } else {
261 node_id =
262 resource->next_node(tree_id, node_id, i, bucketized_features);
263 }
264 }
265 for (int32 j = 0; j < logits_dimension_; ++j) {
266 output_logits(i, j) = tree_logits[j];
267 }
268 }
269 };
270 // 10 is the magic number. The actual number might depend on (the number of
271 // layers in the trees) and (cpu cycles spent on each layer), but this
272 // value would work for many cases. May be tuned later.
273 const int64 cost = (last_tree + 1) * 10;
274 thread::ThreadPool* const worker_threads =
275 context->device()->tensorflow_cpu_worker_threads()->workers;
276 Shard(worker_threads->NumThreads(), worker_threads, batch_size,
277 /*cost_per_unit=*/cost, do_work);
278 }
279
280 private:
281 int32
282 logits_dimension_; // Indicates the size of the output prediction vector.
283 int32 num_bucketized_features_; // Indicates the number of features.
284 };
285
286 REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
287 BoostedTreesPredictOp);
288
289 // The Op that returns debugging/model interpretability outputs for each
290 // example. Currently it outputs the split feature ids and logits after each
291 // split along the decision path for each example. This will be used to compute
292 // directional feature contributions at predict time for an arbitrary activation
293 // function.
294 // TODO(crawles): return in proto 1) Node IDs for ensemble prediction path
295 // 2) Leaf node IDs.
296 class BoostedTreesExampleDebugOutputsOp : public OpKernel {
297 public:
BoostedTreesExampleDebugOutputsOp(OpKernelConstruction * const context)298 explicit BoostedTreesExampleDebugOutputsOp(
299 OpKernelConstruction* const context)
300 : OpKernel(context) {
301 OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
302 &num_bucketized_features_));
303 OP_REQUIRES_OK(context,
304 context->GetAttr("logits_dimension", &logits_dimension_));
305 OP_REQUIRES(context, logits_dimension_ == 1,
306 errors::InvalidArgument(
307 "Currently only one dimensional outputs are supported."));
308 }
309
Compute(OpKernelContext * const context)310 void Compute(OpKernelContext* const context) override {
311 core::RefCountPtr<BoostedTreesEnsembleResource> resource;
312 // Get the resource.
313 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
314 &resource));
315
316 // Get the inputs.
317 OpInputList bucketized_features_list;
318 OP_REQUIRES_OK(context, context->input_list("bucketized_features",
319 &bucketized_features_list));
320 std::vector<tensorflow::TTypes<int32>::ConstMatrix> bucketized_features;
321 bucketized_features.reserve(bucketized_features_list.size());
322 ConvertVectorsToMatrices(bucketized_features_list, bucketized_features);
323 const int batch_size = bucketized_features[0].dimension(0);
324
325 // We need to get the feature ids used for splitting and the logits after
326 // each split. We will use these to calculate the changes in the prediction
327 // (contributions) for an arbitrary activation function (done in Python) and
328 // attribute them to the associated feature ids. We will store these in
329 // a proto below.
330 Tensor* output_debug_info_t = nullptr;
331 OP_REQUIRES_OK(
332 context, context->allocate_output("examples_debug_outputs_serialized",
333 {batch_size}, &output_debug_info_t));
334 // Will contain serialized protos, per example.
335 auto output_debug_info = output_debug_info_t->flat<tstring>();
336 const int32 last_tree = resource->num_trees() - 1;
337
338 // For each given example, traverse through all trees keeping track of the
339 // features used to split and the associated logits at each point along the
340 // path. Note: feature_ids has one less value than logits_path because the
341 // first value of each logit path will be the bias.
342 auto do_work = [&resource, &bucketized_features, &output_debug_info,
343 last_tree](int64 start, int64 end) {
344 for (int32 i = start; i < end; ++i) {
345 // Proto to store debug outputs, per example.
346 boosted_trees::DebugOutput example_debug_info;
347 // Initial bias prediction. E.g., prediction based off training mean.
348 const auto& tree_logits = resource->node_value(0, 0);
349 DCHECK_EQ(tree_logits.size(), 1);
350 float tree_logit = resource->GetTreeWeight(0) * tree_logits[0];
351 example_debug_info.add_logits_path(tree_logit);
352 int32 node_id = 0;
353 int32 tree_id = 0;
354 int32 feature_id;
355 float past_trees_logit = 0; // Sum of leaf logits from prior trees.
356 // Go through each tree and populate proto.
357 while (tree_id <= last_tree) {
358 if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
359 // Accumulate tree_logits only if the leaf is non-root, but do so
360 // for bias tree.
361 if (tree_id == 0 || node_id > 0) {
362 past_trees_logit += tree_logit;
363 }
364 example_debug_info.add_leaf_node_ids(node_id);
365 ++tree_id;
366 node_id = 0;
367 } else { // Add to proto.
368 // Feature id used to split.
369 feature_id = resource->feature_id(tree_id, node_id);
370 example_debug_info.add_feature_ids(feature_id);
371 // Get logit after split.
372 node_id =
373 resource->next_node(tree_id, node_id, i, bucketized_features);
374 const auto& tree_logits = resource->node_value(tree_id, node_id);
375 DCHECK_EQ(tree_logits.size(), 1);
376 tree_logit = resource->GetTreeWeight(tree_id) * tree_logits[0];
377 // Output logit incorporates sum of leaf logits from prior trees.
378 example_debug_info.add_logits_path(tree_logit + past_trees_logit);
379 }
380 }
381 // Set output as serialized proto containing debug info.
382 string serialized = example_debug_info.SerializeAsString();
383 output_debug_info(i) = serialized;
384 }
385 };
386
387 // 10 is the magic number. The actual number might depend on (the number of
388 // layers in the trees) and (cpu cycles spent on each layer), but this
389 // value would work for many cases. May be tuned later.
390 const int64 cost = (last_tree + 1) * 10;
391 thread::ThreadPool* const worker_threads =
392 context->device()->tensorflow_cpu_worker_threads()->workers;
393 Shard(worker_threads->NumThreads(), worker_threads, batch_size,
394 /*cost_per_unit=*/cost, do_work);
395 }
396
397 private:
398 int32 logits_dimension_; // Indicates dimension of logits in the tree nodes.
399 int32 num_bucketized_features_; // Indicates the number of features.
400 };
401
402 REGISTER_KERNEL_BUILDER(
403 Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU),
404 BoostedTreesExampleDebugOutputsOp);
405
406 } // namespace tensorflow
407