1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 18 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/platform/mutex.h" 21 #include "tensorflow/core/platform/protobuf.h" 22 23 namespace tensorflow { 24 25 // Forward declaration for proto class TreeEnsemble 26 namespace boosted_trees { 27 class TreeEnsemble; 28 } // namespace boosted_trees 29 30 // A StampedResource is a resource that has a stamp token associated with it. 31 // Before reading from or applying updates to the resource, the stamp should 32 // be checked to verify that the update is not stale. 33 class StampedResource : public ResourceBase { 34 public: StampedResource()35 StampedResource() : stamp_(-1) {} 36 is_stamp_valid(int64 stamp)37 bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; } 38 stamp()39 int64 stamp() const { return stamp_; } set_stamp(int64 stamp)40 void set_stamp(int64 stamp) { stamp_ = stamp; } 41 42 private: 43 int64 stamp_; 44 }; 45 46 // Keep a tree ensemble in memory for efficient evaluation and mutation. 47 class BoostedTreesEnsembleResource : public StampedResource { 48 public: 49 BoostedTreesEnsembleResource(); 50 51 string DebugString() const override; 52 53 bool InitFromSerialized(const string& serialized, const int64 stamp_token); 54 55 string SerializeAsString() const; 56 57 int32 num_trees() const; 58 59 // Find the next node to which the example (specified by index_in_batch) 60 // traverses down from the current node indicated by tree_id and node_id. 61 // Args: 62 // tree_id: the index of the tree in the ensemble. 63 // node_id: the index of the node within the tree. 64 // index_in_batch: the index of the example within the batch (relevant to 65 // the index of the row to read in each bucketized_features). 66 // bucketized_features: vector of feature Vectors. 67 int32 next_node( 68 const int32 tree_id, const int32 node_id, const int32 index_in_batch, 69 const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const; 70 71 std::vector<float> node_value(const int32 tree_id, const int32 node_id) const; 72 73 void set_node_value(const int32 tree_id, const int32 node_id, 74 const float logits); 75 76 int32 GetNumLayersGrown(const int32 tree_id) const; 77 78 void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const; 79 80 void UpdateLastLayerNodesRange(const int32 node_range_start, 81 int32 node_range_end) const; 82 83 void GetLastLayerNodesRange(int32* node_range_start, 84 int32* node_range_end) const; 85 86 int64 GetNumNodes(const int32 tree_id); 87 88 void UpdateGrowingMetadata() const; 89 90 int32 GetNumLayersAttempted(); 91 92 bool is_leaf(const int32 tree_id, const int32 node_id) const; 93 94 int32 feature_id(const int32 tree_id, const int32 node_id) const; 95 96 int32 bucket_threshold(const int32 tree_id, const int32 node_id) const; 97 98 int32 left_id(const int32 tree_id, const int32 node_id) const; 99 100 int32 right_id(const int32 tree_id, const int32 node_id) const; 101 102 // Add a tree to the ensemble and returns a new tree_id. 103 int32 AddNewTree(const float weight); 104 105 // Adds new tree with one node to the ensemble and sets node's value to logits 106 int32 AddNewTreeWithLogits(const float weight, const float logits); 107 108 // Grows the tree by adding a split and leaves. 109 void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id, 110 const int32 feature_id, const int32 threshold, 111 const float gain, const float left_contrib, 112 const float right_contrib, int32* left_node_id, 113 int32* right_node_id); 114 115 // Retrieves tree weights and returns as a vector. 116 // It involves a copy, so should be called only sparingly (like once per 117 // iteration, not per example). 118 std::vector<float> GetTreeWeights() const; 119 120 float GetTreeWeight(const int32 tree_id) const; 121 122 float IsTreeFinalized(const int32 tree_id) const; 123 124 float IsTreePostPruned(const int32 tree_id) const; 125 126 void SetIsFinalized(const int32 tree_id, const bool is_finalized); 127 128 // Sets the weight of i'th tree. 129 void SetTreeWeight(const int32 tree_id, const float weight); 130 131 // Resets the resource and frees the protos in arena. 132 // Caller needs to hold the mutex lock while calling this. 133 virtual void Reset(); 134 135 void PostPruneTree(const int32 current_tree); 136 137 // For a given node, returns the id in a pruned tree, as well as correction 138 // to the cached prediction that should be applied. If tree was not 139 // post-pruned, current_node_id will be equal to initial_node_id and logit 140 // update will be equal to zero. 141 void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id, 142 int32* current_node_id, 143 float* logit_update) const; get_mutex()144 mutex* get_mutex() { return &mu_; } 145 146 private: 147 // Helper method to check whether a node is a terminal node in that it 148 // only has leaf nodes as children. 149 bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const; 150 151 // For each pruned node, finds the leaf where it finally ended up and 152 // calculates the total update from that pruned node prediction. 153 void CalculateParentAndLogitUpdate( 154 const int32 start_node_id, 155 const std::vector<std::pair<int32, float>>& nodes_change, 156 int32* parent_id, float* change) const; 157 158 // Helper method to collect the information to be used to prune some nodes in 159 // the tree. 160 void RecursivelyDoPostPrunePreparation( 161 const int32 tree_id, const int32 node_id, 162 std::vector<int32>* nodes_to_delete, 163 std::vector<std::pair<int32, float>>* nodes_meta); 164 165 protected: 166 protobuf::Arena arena_; 167 mutex mu_; 168 boosted_trees::TreeEnsemble* tree_ensemble_; 169 }; 170 171 } // namespace tensorflow 172 173 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 174