1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
17 
18 #include <vector>
19 
20 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
21 #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/lib/core/threadpool.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 namespace boosted_trees {
28 namespace models {
29 
30 // Multiple additive trees prediction model.
31 // This class does not hold state and is thread safe.
32 class MultipleAdditiveTrees {
33  public:
34   // Predict runs tree ensemble on the given batch and updates
35   // output predictions accordingly, for the given list of trees.
36   // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not
37   // nullptr, this method fills output_leaf_indices with a per-tree leaf id
38   // where each of the instances from 'features' ended up in. Its shape is num
39   // examples X num of trees.
40   static void Predict(
41       const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
42       const std::vector<int32>& trees_to_include,
43       const boosted_trees::utils::BatchFeatures& features,
44       tensorflow::thread::ThreadPool* const worker_threads,
45       tensorflow::TTypes<float>::Matrix output_predictions,
46       Tensor* const output_leaf_index);
47 };
48 
49 }  // namespace models
50 }  // namespace boosted_trees
51 }  // namespace tensorflow
52 
53 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
54