1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
18 
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/platform/mutex.h"
21 #include "tensorflow/core/platform/protobuf.h"
22 
23 namespace tensorflow {
24 
// The TreeEnsemble proto class is only forward-declared here; this header
// refers to it solely through a pointer.
namespace boosted_trees {
class TreeEnsemble;
}  // namespace boosted_trees
29 
30 // A StampedResource is a resource that has a stamp token associated with it.
31 // Before reading from or applying updates to the resource, the stamp should
32 // be checked to verify that the update is not stale.
33 class StampedResource : public ResourceBase {
34  public:
StampedResource()35   StampedResource() : stamp_(-1) {}
36 
is_stamp_valid(int64 stamp)37   bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
38 
stamp()39   int64 stamp() const { return stamp_; }
set_stamp(int64 stamp)40   void set_stamp(int64 stamp) { stamp_ = stamp; }
41 
42  private:
43   int64 stamp_;
44 };
45 
46 // Keep a tree ensemble in memory for efficient evaluation and mutation.
47 class BoostedTreesEnsembleResource : public StampedResource {
48  public:
49   BoostedTreesEnsembleResource();
50 
51   string DebugString() const override;
52 
53   bool InitFromSerialized(const string& serialized, const int64 stamp_token);
54 
55   string SerializeAsString() const;
56 
57   int32 num_trees() const;
58 
59   // Find the next node to which the example (specified by index_in_batch)
60   // traverses down from the current node indicated by tree_id and node_id.
61   // Args:
62   //   tree_id: the index of the tree in the ensemble.
63   //   node_id: the index of the node within the tree.
64   //   index_in_batch: the index of the example within the batch (relevant to
65   //       the index of the row to read in each bucketized_features).
66   //   bucketized_features: vector of feature Vectors.
67   int32 next_node(
68       const int32 tree_id, const int32 node_id, const int32 index_in_batch,
69       const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const;
70 
71   std::vector<float> node_value(const int32 tree_id, const int32 node_id) const;
72 
73   void set_node_value(const int32 tree_id, const int32 node_id,
74                       const float logits);
75 
76   int32 GetNumLayersGrown(const int32 tree_id) const;
77 
78   void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
79 
80   void UpdateLastLayerNodesRange(const int32 node_range_start,
81                                  int32 node_range_end) const;
82 
83   void GetLastLayerNodesRange(int32* node_range_start,
84                               int32* node_range_end) const;
85 
86   int64 GetNumNodes(const int32 tree_id);
87 
88   void UpdateGrowingMetadata() const;
89 
90   int32 GetNumLayersAttempted();
91 
92   bool is_leaf(const int32 tree_id, const int32 node_id) const;
93 
94   int32 feature_id(const int32 tree_id, const int32 node_id) const;
95 
96   int32 bucket_threshold(const int32 tree_id, const int32 node_id) const;
97 
98   int32 left_id(const int32 tree_id, const int32 node_id) const;
99 
100   int32 right_id(const int32 tree_id, const int32 node_id) const;
101 
102   // Add a tree to the ensemble and returns a new tree_id.
103   int32 AddNewTree(const float weight);
104 
105   // Adds new tree with one node to the ensemble and sets node's value to logits
106   int32 AddNewTreeWithLogits(const float weight, const float logits);
107 
108   // Grows the tree by adding a split and leaves.
109   void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
110                               const int32 feature_id, const int32 threshold,
111                               const float gain, const float left_contrib,
112                               const float right_contrib, int32* left_node_id,
113                               int32* right_node_id);
114 
115   // Retrieves tree weights and returns as a vector.
116   // It involves a copy, so should be called only sparingly (like once per
117   // iteration, not per example).
118   std::vector<float> GetTreeWeights() const;
119 
120   float GetTreeWeight(const int32 tree_id) const;
121 
122   float IsTreeFinalized(const int32 tree_id) const;
123 
124   float IsTreePostPruned(const int32 tree_id) const;
125 
126   void SetIsFinalized(const int32 tree_id, const bool is_finalized);
127 
128   // Sets the weight of i'th tree.
129   void SetTreeWeight(const int32 tree_id, const float weight);
130 
131   // Resets the resource and frees the protos in arena.
132   // Caller needs to hold the mutex lock while calling this.
133   virtual void Reset();
134 
135   void PostPruneTree(const int32 current_tree);
136 
137   // For a given node, returns the id in a pruned tree, as well as correction
138   // to the cached prediction that should be applied. If tree was not
139   // post-pruned, current_node_id will be equal to initial_node_id and logit
140   // update will be equal to zero.
141   void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id,
142                               int32* current_node_id,
143                               float* logit_update) const;
get_mutex()144   mutex* get_mutex() { return &mu_; }
145 
146  private:
147   // Helper method to check whether a node is a terminal node in that it
148   // only has leaf nodes as children.
149   bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const;
150 
151   // For each pruned node, finds the leaf where it finally ended up and
152   // calculates the total update from that pruned node prediction.
153   void CalculateParentAndLogitUpdate(
154       const int32 start_node_id,
155       const std::vector<std::pair<int32, float>>& nodes_change,
156       int32* parent_id, float* change) const;
157 
158   // Helper method to collect the information to be used to prune some nodes in
159   // the tree.
160   void RecursivelyDoPostPrunePreparation(
161       const int32 tree_id, const int32 node_id,
162       std::vector<int32>* nodes_to_delete,
163       std::vector<std::pair<int32, float>>* nodes_meta);
164 
165  protected:
166   protobuf::Arena arena_;
167   mutex mu_;
168   boosted_trees::TreeEnsemble* tree_ensemble_;
169 };
170 
171 }  // namespace tensorflow
172 
173 #endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
174