1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
17 
18 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
19 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
20 #include "tensorflow/core/framework/resource_mgr.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 namespace boosted_trees {
26 namespace models {
27 
28 // Keep a tree ensemble in memory for efficient evaluation and mutation.
29 class DecisionTreeEnsembleResource : public StampedResource {
30  public:
31   // Constructor.
DecisionTreeEnsembleResource()32   explicit DecisionTreeEnsembleResource()
33       : decision_tree_ensemble_(
34             protobuf::Arena::CreateMessage<
35                 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
36 
DebugString()37   string DebugString() const override {
38     return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
39                            decision_tree_ensemble_->trees_size(), "]");
40   }
41 
42   const boosted_trees::trees::DecisionTreeEnsembleConfig&
decision_tree_ensemble()43   decision_tree_ensemble() const {
44     return *decision_tree_ensemble_;
45   }
46 
num_trees()47   int32 num_trees() const { return decision_tree_ensemble_->trees_size(); }
48 
InitFromSerialized(const string & serialized,const int64 stamp_token)49   bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
50     CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
51     if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) {
52       set_stamp(stamp_token);
53       return true;
54     }
55     return false;
56   }
57 
SerializeAsString()58   string SerializeAsString() const {
59     return decision_tree_ensemble_->SerializeAsString();
60   }
61 
62   // Increment num_layers_attempted and num_trees_attempted in growing_metadata
63   // if the tree is finalized.
IncrementAttempts()64   void IncrementAttempts() {
65     boosted_trees::trees::GrowingMetadata* const growing_metadata =
66         decision_tree_ensemble_->mutable_growing_metadata();
67     growing_metadata->set_num_layers_attempted(
68         growing_metadata->num_layers_attempted() + 1);
69     const int num_trees = decision_tree_ensemble_->trees_size();
70     if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) {
71       growing_metadata->set_num_trees_attempted(
72           growing_metadata->num_trees_attempted() + 1);
73     }
74   }
75 
AddNewTree(const float weight)76   boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) {
77     // Adding a tree as well as a weight and a tree_metadata.
78     decision_tree_ensemble_->add_tree_weights(weight);
79     boosted_trees::trees::DecisionTreeMetadata* const metadata =
80         decision_tree_ensemble_->add_tree_metadata();
81     metadata->set_num_layers_grown(1);
82     return decision_tree_ensemble_->add_trees();
83   }
84 
RemoveLastTree()85   void RemoveLastTree() {
86     QCHECK_GT(decision_tree_ensemble_->trees_size(), 0);
87     decision_tree_ensemble_->mutable_trees()->RemoveLast();
88     decision_tree_ensemble_->mutable_tree_weights()->RemoveLast();
89     decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast();
90   }
91 
LastTree()92   boosted_trees::trees::DecisionTreeConfig* LastTree() {
93     const int32 tree_size = decision_tree_ensemble_->trees_size();
94     QCHECK_GT(tree_size, 0);
95     return decision_tree_ensemble_->mutable_trees(tree_size - 1);
96   }
97 
LastTreeMetadata()98   boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() {
99     const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size();
100     QCHECK_GT(metadata_size, 0);
101     return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1);
102   }
103 
104   // Retrieves tree weights and returns as a vector.
GetTreeWeights()105   std::vector<float> GetTreeWeights() const {
106     return {decision_tree_ensemble_->tree_weights().begin(),
107             decision_tree_ensemble_->tree_weights().end()};
108   }
109 
GetTreeWeight(const int32 index)110   float GetTreeWeight(const int32 index) const {
111     return decision_tree_ensemble_->tree_weights(index);
112   }
113 
MaybeAddUsedHandler(const int32 handler_id)114   void MaybeAddUsedHandler(const int32 handler_id) {
115     protobuf::RepeatedField<protobuf_int64>* used_ids =
116         decision_tree_ensemble_->mutable_growing_metadata()
117             ->mutable_used_handler_ids();
118     protobuf::RepeatedField<protobuf_int64>::iterator first =
119         std::lower_bound(used_ids->begin(), used_ids->end(), handler_id);
120     if (first == used_ids->end()) {
121       used_ids->Add(handler_id);
122       return;
123     }
124     if (handler_id == *first) {
125       // It is a duplicate entry.
126       return;
127     }
128     used_ids->Add(handler_id);
129     // Keep the list of used handlers sorted.
130     std::sort(used_ids->begin(), used_ids->end());
131   }
132 
GetUsedHandlers()133   std::vector<int64> GetUsedHandlers() const {
134     std::vector<int64> result;
135     result.reserve(
136         decision_tree_ensemble_->growing_metadata().used_handler_ids().size());
137     for (int64 h :
138          decision_tree_ensemble_->growing_metadata().used_handler_ids()) {
139       result.push_back(h);
140     }
141     return result;
142   }
143 
144   // Sets the weight of i'th tree, and increment num_updates in tree_metadata.
SetTreeWeight(const int32 index,const float weight,const int32 increment_num_updates)145   void SetTreeWeight(const int32 index, const float weight,
146                      const int32 increment_num_updates) {
147     QCHECK_GE(index, 0);
148     QCHECK_LT(index, num_trees());
149     decision_tree_ensemble_->set_tree_weights(index, weight);
150     if (increment_num_updates != 0) {
151       const int32 num_updates = decision_tree_ensemble_->tree_metadata(index)
152                                     .num_tree_weight_updates();
153       decision_tree_ensemble_->mutable_tree_metadata(index)
154           ->set_num_tree_weight_updates(num_updates + increment_num_updates);
155     }
156   }
157 
158   // Resets the resource and frees the protos in arena.
159   // Caller needs to hold the mutex lock while calling this.
Reset()160   virtual void Reset() {
161     // Reset stamp.
162     set_stamp(-1);
163 
164     // Clear tree ensemle.
165     arena_.Reset();
166     CHECK_EQ(0, arena_.SpaceAllocated());
167     decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
168         boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
169   }
170 
get_mutex()171   mutex* get_mutex() { return &mu_; }
172 
173  protected:
174   protobuf::Arena arena_;
175   mutex mu_;
176   boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
177 };
178 
179 }  // namespace models
180 }  // namespace boosted_trees
181 }  // namespace tensorflow
182 
183 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
184