1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
18 
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 
// Forward declarations for the proto classes TreeEnsemble and Node.
27 namespace boosted_trees {
28 class TreeEnsemble;
29 class Node;
30 }  // namespace boosted_trees
31 
32 // A StampedResource is a resource that has a stamp token associated with it.
33 // Before reading from or applying updates to the resource, the stamp should
34 // be checked to verify that the update is not stale.
35 class StampedResource : public ResourceBase {
36  public:
StampedResource()37   StampedResource() : stamp_(-1) {}
38 
is_stamp_valid(int64 stamp)39   bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
40 
stamp()41   int64 stamp() const { return stamp_; }
set_stamp(int64 stamp)42   void set_stamp(int64 stamp) { stamp_ = stamp; }
43 
44  private:
45   int64 stamp_;
46 };
47 
48 // Keep a tree ensemble in memory for efficient evaluation and mutation.
49 class BoostedTreesEnsembleResource : public StampedResource {
50  public:
51   BoostedTreesEnsembleResource();
52 
53   string DebugString() const override;
54 
55   bool InitFromSerialized(const string& serialized, const int64 stamp_token);
56 
57   string SerializeAsString() const;
58 
59   int32 num_trees() const;
60 
61   // Find the next node to which the example (specified by index_in_batch)
62   // traverses down from the current node indicated by tree_id and node_id.
63   // Args:
64   //   tree_id: the index of the tree in the ensemble.
65   //   node_id: the index of the node within the tree.
66   //   index_in_batch: the index of the example within the batch (relevant to
67   //       the index of the row to read in each bucketized_features).
68   //   bucketized_features: vector of feature Vectors.
69   int32 next_node(
70       const int32 tree_id, const int32 node_id, const int32 index_in_batch,
71       const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const;
72 
73   std::vector<float> node_value(const int32 tree_id, const int32 node_id) const;
74 
75   void set_node_value(const int32 tree_id, const int32 node_id,
76                       const float logits);
77 
78   int32 GetNumLayersGrown(const int32 tree_id) const;
79 
80   void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
81 
82   void UpdateLastLayerNodesRange(const int32 node_range_start,
83                                  int32 node_range_end) const;
84 
85   void GetLastLayerNodesRange(int32* node_range_start,
86                               int32* node_range_end) const;
87 
88   int64 GetNumNodes(const int32 tree_id);
89 
90   void UpdateGrowingMetadata() const;
91 
92   int32 GetNumLayersAttempted();
93 
94   bool is_leaf(const int32 tree_id, const int32 node_id) const;
95 
96   int32 feature_id(const int32 tree_id, const int32 node_id) const;
97 
98   int32 bucket_threshold(const int32 tree_id, const int32 node_id) const;
99 
100   int32 left_id(const int32 tree_id, const int32 node_id) const;
101 
102   int32 right_id(const int32 tree_id, const int32 node_id) const;
103 
104   // Add a tree to the ensemble and returns a new tree_id.
105   int32 AddNewTree(const float weight, const int32 logits_dimension);
106 
107   // Adds new tree with one node to the ensemble and sets node's value to logits
108   int32 AddNewTreeWithLogits(const float weight,
109                              const std::vector<float>& logits,
110                              const int32 logits_dimension);
111 
112   // Grows the tree by adding a bucketized split and leaves.
113   void AddBucketizedSplitNode(
114       const int32 tree_id,
115       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
116       const int32 logits_dimension, int32* left_node_id, int32* right_node_id);
117 
118   // Grows the tree by adding a categorical split and leaves.
119   void AddCategoricalSplitNode(
120       const int32 tree_id,
121       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
122       const int32 logits_dimension, int32* left_node_id, int32* right_node_id);
123 
124   // Retrieves tree weights and returns as a vector.
125   // It involves a copy, so should be called only sparingly (like once per
126   // iteration, not per example).
127   std::vector<float> GetTreeWeights() const;
128 
129   float GetTreeWeight(const int32 tree_id) const;
130 
131   float IsTreeFinalized(const int32 tree_id) const;
132 
133   float IsTreePostPruned(const int32 tree_id) const;
134 
135   void SetIsFinalized(const int32 tree_id, const bool is_finalized);
136 
137   // Sets the weight of i'th tree.
138   void SetTreeWeight(const int32 tree_id, const float weight);
139 
140   // Resets the resource and frees the protos in arena.
141   // Caller needs to hold the mutex lock while calling this.
142   virtual void Reset();
143 
144   void PostPruneTree(const int32 current_tree, const int32 logits_dimension);
145 
146   // For a given node, returns the id in a pruned tree, as well as correction
147   // to the cached prediction that should be applied. If tree was not
148   // post-pruned, current_node_id will be equal to initial_node_id and logit
149   // update will be equal to zero.
150   void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id,
151                               int32* current_node_id,
152                               std::vector<float>* logit_updates) const;
get_mutex()153   mutex* get_mutex() { return &mu_; }
154 
155  private:
156   // Helper method to check whether a node is a terminal node in that it
157   // only has leaf nodes as children.
158   bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const;
159 
160   // For each pruned node, finds the leaf where it finally ended up and
161   // calculates the total update from that pruned node prediction.
162   void CalculateParentAndLogitUpdate(
163       const int32 start_node_id,
164       const std::vector<std::pair<int32, std::vector<float>>>& nodes_change,
165       int32* parent_id, std::vector<float>* change) const;
166 
167   // Helper method to collect the information to be used to prune some nodes in
168   // the tree.
169   void RecursivelyDoPostPrunePreparation(
170       const int32 tree_id, const int32 node_id,
171       std::vector<int32>* nodes_to_delete,
172       std::vector<std::pair<int32, std::vector<float>>>* nodes_meta);
173 
174  protected:
175   protobuf::Arena arena_;
176   mutex mu_;
177   boosted_trees::TreeEnsemble* tree_ensemble_;
178 
179   boosted_trees::Node* AddLeafNodes(
180       int32 tree_id,
181       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
182       const int32 logits_dimension, int32* left_node_id, int32* right_node_id);
183 };
184 
185 }  // namespace tensorflow
186 
187 #endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
188