1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
16 
17 #include <iterator>
18 #include <numeric>
19 #include <unordered_set>
20 
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/simple_philox.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 using tensorflow::Status;
27 using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
28 using tensorflow::random::PhiloxRandom;
29 using tensorflow::random::SimplePhilox;
30 
31 namespace tensorflow {
32 namespace boosted_trees {
33 namespace utils {
34 
DropOutTrees(const uint64 seed,const LearningRateDropoutDrivenConfig & config,const std::unordered_set<int32> & trees_not_to_drop,const std::vector<float> & weights,std::vector<int32> * dropped_trees,std::vector<float> * original_weights)35 Status DropoutUtils::DropOutTrees(
36     const uint64 seed, const LearningRateDropoutDrivenConfig& config,
37     const std::unordered_set<int32>& trees_not_to_drop,
38     const std::vector<float>& weights, std::vector<int32>* dropped_trees,
39     std::vector<float>* original_weights) {
40   // Verify params.
41   if (dropped_trees == nullptr) {
42     return errors::Internal("Dropped trees is nullptr.");
43   }
44   if (original_weights == nullptr) {
45     return errors::InvalidArgument("Original weights is nullptr.");
46   }
47   const float dropout_probability = config.dropout_probability();
48   if (dropout_probability < 0 || dropout_probability > 1) {
49     return errors::InvalidArgument(
50         "Dropout probability must be in [0,1] range");
51   }
52   const float probability_of_skipping_dropout =
53       config.probability_of_skipping_dropout();
54   if (probability_of_skipping_dropout < 0 ||
55       probability_of_skipping_dropout > 1) {
56     return errors::InvalidArgument(
57         "Probability of skipping dropout must be in [0,1] range");
58   }
59   const auto num_trees = weights.size();
60 
61   dropped_trees->clear();
62   original_weights->clear();
63 
64   // If dropout is no op, return.
65   if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) {
66     return Status::OK();
67   }
68 
69   // Roll the dice for each tree.
70   PhiloxRandom philox(seed);
71   SimplePhilox rng(&philox);
72 
73   std::vector<int32> trees_to_keep;
74 
75   // What is the probability of skipping dropout altogether.
76   if (probability_of_skipping_dropout != 0) {
77     // First roll the dice - do we do dropout
78     double roll = rng.RandDouble();
79     if (roll < probability_of_skipping_dropout) {
80       // don't do dropout
81       return Status::OK();
82     }
83   }
84 
85   for (int32 i = 0; i < num_trees; ++i) {
86     // We can't drop some of the trees: for example, bias tree in batch mode,
87     // or current tree that is built, in the batch mode.
88     if (trees_not_to_drop.find(i) != trees_not_to_drop.end()) {
89       continue;
90     }
91     double roll = rng.RandDouble();
92     if (roll >= dropout_probability) {
93       trees_to_keep.push_back(i);
94     } else {
95       dropped_trees->push_back(i);
96     }
97   }
98 
99   // Sort the dropped trees indices.
100   std::sort(dropped_trees->begin(), dropped_trees->end());
101   for (const int32 dropped_tree : *dropped_trees) {
102     original_weights->push_back(weights[dropped_tree]);
103   }
104 
105   return Status::OK();
106 }
107 
GetTreesWeightsForAddingTrees(const std::vector<int32> & dropped_trees,const std::vector<float> & dropped_trees_original_weights,const int32 new_trees_first_index,const int32 num_trees_to_add,std::vector<float> * current_weights,std::vector<int32> * num_updates)108 void DropoutUtils::GetTreesWeightsForAddingTrees(
109     const std::vector<int32>& dropped_trees,
110     const std::vector<float>& dropped_trees_original_weights,
111     const int32 new_trees_first_index, const int32 num_trees_to_add,
112     std::vector<float>* current_weights, std::vector<int32>* num_updates) {
113   CHECK(num_updates->size() == current_weights->size());
114   // combined weight of trees that were dropped out
115 
116   const float dropped_sum =
117       std::accumulate(dropped_trees_original_weights.begin(),
118                       dropped_trees_original_weights.end(), 0.0);
119 
120   const int num_dropped = dropped_trees.size();
121 
122   // Allocate additional weight for the new tree
123   const float total_new_trees_weight = dropped_sum / (num_dropped + 1);
124 
125   for (int i = 0; i < num_trees_to_add; ++i) {
126     const int32 new_tree_index = new_trees_first_index + i;
127     if (new_tree_index < current_weights->size()) {
128       // We have the entries in weights and updates for this tree already
129       (*current_weights)[new_tree_index] =
130           total_new_trees_weight / num_trees_to_add;
131       (*num_updates)[new_tree_index]++;
132     } else {
133       // We need to add a new entry. This is non-batch mode.
134       current_weights->push_back(total_new_trees_weight / num_trees_to_add);
135       num_updates->push_back(1);
136     }
137   }
138 
139   for (int32 i = 0; i < dropped_trees.size(); ++i) {
140     const int32 dropped = dropped_trees[i];
141     const float original_weight = dropped_trees_original_weights[i];
142     const float new_weight = original_weight * num_dropped / (num_dropped + 1);
143     (*current_weights)[dropped] = new_weight;
144     // Update the number of updates per tree.
145     ++(*num_updates)[dropped];
146   }
147 }
148 
149 }  // namespace utils
150 }  // namespace boosted_trees
151 }  // namespace tensorflow
152