1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
16 
17 #include <cfloat>
18 #include <queue>
19 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
20 #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
21 #include "tensorflow/core/lib/random/distribution_sampler.h"
22 #include "tensorflow/core/lib/random/random.h"
23 
24 namespace tensorflow {
25 namespace tensorforest {
26 
// When creating evaluators for the split candidates, use these
// for the left and right return values.  Decide() on an evaluator
// reports which child an example falls into by returning one of them
// (see ClassificationStats::AddExample).
static const int32 LEFT_INDEX = 0;
static const int32 RIGHT_INDEX = 1;
31 
// Initializes all counters to zero and resolves the depth-dependent
// parameters (split_after_samples, num_splits_to_consider) for this
// node's depth via ResolveParam.
GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
    : weight_sum_(0),
      depth_(depth),
      params_(params),
      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
      num_splits_to_consider_(
          ResolveParam(params.num_splits_to_consider(), depth)),
      num_outputs_(params.num_outputs()) {}
40 
AddSplit(const decision_trees::BinaryNode & split,const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,int example)41 void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
42                          const std::unique_ptr<TensorDataSet>& input_data,
43                          const InputTarget* target, int example) {
44   // It's possible that the split collection calls AddSplit, but we actually
45   // have all the splits we need and are just waiting for them to be fully
46   // initialized.
47   if (splits_.size() < num_splits_to_consider_) {
48     splits_.push_back(split);
49     evaluators_.emplace_back(
50         CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
51     AddSplitStats(target, example);
52   }
53 
54   if (input_data != nullptr && target != nullptr &&
55       params_.initialize_average_splits()) {
56     AdditionalInitializationExample(input_data, target, example);
57   }
58 }
59 
RemoveSplit(int split_num)60 void GrowStats::RemoveSplit(int split_num) {
61   splits_.erase(splits_.begin() + split_num);
62   evaluators_.erase(evaluators_.begin() + split_num);
63   RemoveSplitStats(split_num);
64 }
65 
66 // ------------------------ Classification --------------------------- //
67 
// Sets up the early-finish strategy, pruning strategy, running-Gini
// caches, and RNG state for classification growing.  Depth-dependent
// parameters are resolved with ResolveParam.
ClassificationStats::ClassificationStats(const TensorForestParams& params,
                                         int32 depth)
    : GrowStats(params, depth), finish_early_(false) {
  // Early splitting params.
  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
    // Basic finish: schedule the first early-finish check past
    // split_after_samples_ so CheckFinishEarly's epoch guard never fires
    // before the node is already finishable via IsFinished.
    min_split_samples_ = split_after_samples_;
    finish_sample_epoch_ = 1;
    finish_check_every_ = split_after_samples_ * 2;
  } else {
    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
      LOG(FATAL) << "dominate_fraction and min_split_samples "
                 << "required for early-finish strategy.";
    } else {
      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
      finish_check_every_ =
          ResolveParam(params.finish_type().check_every_steps(), depth);
      // Start the epoch counter so the first check happens once
      // weight_sum_ reaches min_split_samples_ (see CheckFinishEarly).
      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;

      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
      }
    }
  }

  // Pruning params.
  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
    prune_check_every_ =
        ResolveParam(params.pruning_type().prune_every_samples(), depth);
    prune_sample_epoch_ = 1;
    prune_fraction_ = 0.0;
    switch (params_.pruning_type().type()) {
      case SPLIT_PRUNE_HALF:
        prune_fraction_ = 0.5;
        break;
      case SPLIT_PRUNE_QUARTER:
        prune_fraction_ = 0.25;
        break;
      case SPLIT_PRUNE_10_PERCENT:
        prune_fraction_ = 0.10;
        break;
      case SPLIT_PRUNE_HOEFFDING:
        // Hoeffding pruning uses the bound computed in CheckPruneHoeffding
        // instead of removing a fixed fraction.
        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
        break;
      default:
        LOG(WARNING) << "Unknown pruning type";
    }
  } else {
    // No pruning: push the first prune check past split_after_samples_ so
    // CheckPrune's epoch guard never passes before the node finishes.
    prune_check_every_ = split_after_samples_ * 2;
    prune_sample_epoch_ = 1;
  }

  if (params.use_running_stats_method()) {
    left_gini_.reset(new RunningGiniScores());
    right_gini_.reset(new RunningGiniScores());
  }

  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
      new random::PhiloxRandom(random::New64()));
  rng_ = std::unique_ptr<random::SimplePhilox>(
      new random::SimplePhilox(single_rand_.get()));
}
131 
// Used with params_.initialize_average_splits(): once an example with a
// class label different from a pending split's recorded label arrives,
// move that split's inequality threshold to the midpoint of the stored
// threshold and this example's feature value, then stop tracking it.
void ClassificationStats::AdditionalInitializationExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
  std::unordered_set<int> to_erase;
  // half_initialized_splits_ maps split index -> class label (presumably
  // the label of the example that seeded the split -- TODO confirm against
  // AddSplitStats).
  for (auto it = half_initialized_splits_.begin();
       it != half_initialized_splits_.end(); ++it) {
    if (it->second != new_target) {
      auto& split = splits_[it->first];
      if (split.has_inequality_left_child_test()) {
        auto& test = split.inequality_left_child_test();
        auto* thresh =
            split.mutable_inequality_left_child_test()->mutable_threshold();
        if (test.has_feature_id()) {
          const float val =
              input_data->GetExampleValue(example, test.feature_id());
          // Average the seed threshold with this example's feature value.
          thresh->set_float_value((thresh->float_value() + val) / 2);
        }
      }
      to_erase.insert(it->first);
    }
  }

  // Erase after the loop so the iteration above is not invalidated.
  for (const int split_id : to_erase) {
    half_initialized_splits_.erase(split_id);
  }
}
159 
IsFinished() const160 bool ClassificationStats::IsFinished() const {
161   bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
162   return basic || finish_early_;
163 }
164 
MaybeCachedGiniScore(int split,float * left_sum,float * right_sum) const165 float ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
166                                                 float* right_sum) const {
167   if (left_gini_ == nullptr) {
168     return GiniScore(split, left_sum, right_sum);
169   } else {
170     *left_sum = left_gini_->sum(split);
171     const float left = WeightedSmoothedGini(
172         *left_sum, left_gini_->square(split), num_outputs_);
173 
174     *right_sum = right_gini_->sum(split);
175     const float right = WeightedSmoothedGini(
176         *right_sum, right_gini_->square(split), num_outputs_);
177 
178     return left + right;
179   }
180 }
181 
// Routes one labeled example through every candidate split, updating that
// side's class counts (and the running Gini caches when enabled), then the
// node totals, and finally runs the early-finish and prune checks.
void ClassificationStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
  const float weight = target->GetTargetWeight(example);

  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      if (left_gini_ != nullptr) {
        // NOTE: update() receives the class count *before* this example is
        // added; ClassificationAddLeftExample below does the increment.
        left_gini_->update(i, left_count(i, int_label), weight);
      }
      ClassificationAddLeftExample(i, int_label, weight);
    } else {
      if (right_gini_ != nullptr) {
        // Same pre-increment ordering as the left branch above.
        right_gini_->update(i, right_count(i, int_label), weight);
      }
      ClassificationAddRightExample(i, int_label, weight);
    }
  }

  ClassificationAddTotalExample(int_label, weight);

  weight_sum_ += weight;

  CheckFinishEarly();
  CheckPrune();
}
210 
CheckPrune()211 void ClassificationStats::CheckPrune() {
212   if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
213       weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
214     return;
215   }
216   ++prune_sample_epoch_;
217 
218   if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
219     CheckPruneHoeffding();
220     return;
221   }
222 
223   const int to_remove = num_splits() * prune_fraction_;
224   if (to_remove <= 0) {
225     return;
226   }
227 
228   // pair ordering is first-then-second by default, no need for custom
229   // comparison.  Use std::greater to make it a min-heap.
230   std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
231                       std::greater<std::pair<float, int>>>
232       worst;
233 
234   // Track indices that are in the heap so we can iterate over them
235   // by largest-first later.
236   std::set<int> indices;
237 
238   for (int i = 0; i < num_splits(); ++i) {
239     float left, right;
240     const float split_score = MaybeCachedGiniScore(i, &left, &right);
241     if (worst.size() < to_remove) {
242       worst.push(std::pair<float, int>(split_score, i));
243       indices.insert(i);
244     } else if (worst.top().first < split_score) {
245       indices.erase(worst.top().second);
246       worst.pop();
247       worst.push(std::pair<float, int>(split_score, i));
248       indices.insert(i);
249     }
250   }
251 
252   // traverse indices from the back so that they are removed correctly.
253   for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
254     RemoveSplit(*it);
255   }
256 }
257 
// Removes every candidate whose Gini score trails the best candidate's
// score by more than the Hoeffding-bound epsilon, i.e. candidates that are
// statistically dominated at confidence dominate_fraction_.
void ClassificationStats::CheckPruneHoeffding() {
  std::vector<float> split_scores(num_splits());
  // Find best split score
  float best_split_score = FLT_MAX;
  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
    if (split_scores[i] < best_split_score) {
      best_split_score = split_scores[i];
    }
  }

  // We apply the Hoeffding bound to the difference between the best split
  // score and the i-th split score.
  // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted.
  const float num_classes = params_.num_outputs();
  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
  // epsilon = range * sqrt(ln(1/(1-f)) / (2 * n)); half_ln_dominate_frac_
  // already holds 0.5 * ln(1/(1-f)) (set in the constructor).
  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
  // Iterate from the back so RemoveSplit's index shifting cannot affect
  // indices we have not visited yet.
  for (int i = num_splits() - 1; i >= 0; i--) {
    if (split_scores[i] - best_split_score > epsilon) {
      RemoveSplit(i);
    }
  }
}
282 
CheckFinishEarly()283 void ClassificationStats::CheckFinishEarly() {
284   if (weight_sum_ < min_split_samples_ ||
285       weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
286     return;
287   }
288   ++finish_sample_epoch_;
289 
290   if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
291     CheckFinishEarlyHoeffding();
292   } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
293     CheckFinishEarlyBootstrap();
294   }
295 }
296 
// Finishes the node early when the gap between the two best split scores
// exceeds the Hoeffding bound, i.e. the best split dominates the runner-up
// at confidence dominate_fraction_.
void ClassificationStats::CheckFinishEarlyHoeffding() {
  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;

  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));

  // The bound sums are not needed here; bind throwaway locals so the
  // score function only exposes the Gini score.  The pointers stay valid
  // for the lifetime of score_fn because GetTwoBest is called in-scope.
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  // Lower Gini is better, so a large (second_best - best) gap means the
  // best split dominates.
  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
}
318 
MakeBootstrapWeights(int index,std::vector<float> * weights)319 void ClassificationStats::MakeBootstrapWeights(int index,
320                                                std::vector<float>* weights) {
321   int n = weight_sum_;
322   float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
323   for (int i = 0; i < num_outputs_; ++i) {
324     // Use the Laplace smoothed per-class probabilities when generating the
325     // bootstrap samples.
326     (*weights)[i] = (left_count(index, i) + 1.0) / denom;
327     (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
328   }
329 }
330 
NumBootstrapSamples() const331 int ClassificationStats::NumBootstrapSamples() const {
332   float p = 1.0 - dominate_fraction_;
333   int bootstrap_samples = 1;
334   while (p < 1.0) {
335     ++bootstrap_samples;
336     p = p * 2;
337   }
338   return bootstrap_samples;
339 }
340 
// Bootstrap-based early finish: sample bootstrap Gini values for the two
// best candidates and finish early only if the best candidate's *worst*
// sampled Gini is still below the runner-up's *best* sampled Gini.
void ClassificationStats::CheckFinishEarlyBootstrap() {
  // Throwaway sum outputs; only the scores are needed for ranking.
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  // Smoothed left/right class distributions for each candidate's children
  // (2 * num_outputs_ entries each; see MakeBootstrapWeights).
  std::vector<float> weights1(num_outputs_ * 2);
  MakeBootstrapWeights(best_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2(num_outputs_ * 2);
  MakeBootstrapWeights(second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int bootstrap_samples = NumBootstrapSamples();

  int worst_g1 = 0;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
    worst_g1 = std::max(worst_g1, g1);
  }

  // NOTE(review): 99 is assumed to exceed anything BootstrapGini can
  // return -- TODO confirm its range.
  int best_g2 = 99;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
    best_g2 = std::min(best_g2, g2);
  }

  finish_early_ = worst_g1 < best_g2;
}
378 
BestSplit(SplitCandidate * best) const379 bool ClassificationStats::BestSplit(SplitCandidate* best) const {
380   float min_score = FLT_MAX;
381   int best_index = -1;
382   float best_left_sum, best_right_sum;
383 
384   // Calculate sums.
385   for (int i = 0; i < num_splits(); ++i) {
386     float left_sum, right_sum;
387     const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
388     // Find the lowest gini.
389     if (left_sum > 0 && right_sum > 0 &&
390         split_score < min_score) {  // useless check
391       min_score = split_score;
392       best_index = i;
393       best_left_sum = left_sum;
394       best_right_sum = right_sum;
395     }
396   }
397 
398   // This could happen if all the splits are useless.
399   if (best_index < 0) {
400     return false;
401   }
402 
403   // Fill in stats to be used for leaf model.
404   *best->mutable_split() = splits_[best_index];
405   auto* left = best->mutable_left_stats();
406   left->set_weight_sum(best_left_sum);
407   auto* right = best->mutable_right_stats();
408   right->set_weight_sum(best_right_sum);
409   InitLeafClassStats(best_index, left, right);
410 
411   return true;
412 }
413 
414 // ------------------------ Dense Classification --------------------------- //
// Restores this node's accumulated stats from a serialized FertileSlot:
// total per-class counts plus, for each candidate, its split and its
// left-side counts (the right side is implied as total - left, see
// InitLeafClassStats).
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  const int32 num_classes = params_.num_outputs();
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().dense_counts();

  // Total counts.
  for (int i = 0; i < num_classes; ++i) {
    total_counts_[i] = class_stats.value(i).float_value();
    // Tally how many distinct classes carry nonzero weight.
    num_outputs_seen_ += total_counts_[i] != 0;
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    // nullptr data/target: only registers the split, no stat updates.
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().dense_counts();
    for (int i = 0; i < num_classes; ++i) {
      const float val = left_stats.value(i).float_value();
      mutable_left_count(split_num, i) = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}
444 
PackToProto(FertileSlot * slot) const445 void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
446   auto* slot_stats = slot->mutable_post_init_leaf_stats();
447   slot_stats->set_weight_sum(weight_sum_);
448 
449   auto* class_stats = slot->mutable_post_init_leaf_stats()
450                           ->mutable_classification()
451                           ->mutable_dense_counts();
452   for (int i = 0; i < num_outputs_; ++i) {
453     class_stats->add_value()->set_float_value(total_counts_[i]);
454   }
455 
456   for (int split_num = 0; split_num < num_splits(); ++split_num) {
457     auto* cand = slot->add_candidates();
458     *cand->mutable_split() = splits_[split_num];
459     auto* left_stats = cand->mutable_left_stats()
460                            ->mutable_classification()
461                            ->mutable_dense_counts();
462     for (int i = 0; i < num_outputs_; ++i) {
463       left_stats->add_value()->set_float_value(left_count(split_num, i));
464     }
465   }
466 }
467 
GiniScore(int split,float * left_sum,float * right_sum) const468 float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
469                                               float* right_sum) const {
470   float left_square = 0, right_square = 0;
471   *left_sum = 0;
472   *right_sum = 0;
473   for (int j = 0; j < num_outputs_; ++j) {
474     const float left = left_count(split, j);
475     *left_sum += left;
476     left_square += left * left;
477     const float right = right_count(split, j);
478     *right_sum += right;
479     right_square += right * right;
480   }
481 
482   const float left_score =
483       WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
484   const float right_score =
485       WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
486   return left_score + right_score;
487 }
488 
InitLeafClassStats(int best_split_index,LeafStat * left_stats,LeafStat * right_stats) const489 void DenseClassificationGrowStats::InitLeafClassStats(
490     int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
491   auto* left_class_stats = left_stats->mutable_classification();
492   auto* left_counts = left_class_stats->mutable_dense_counts();
493   for (int i = 0; i < params_.num_outputs(); ++i) {
494     left_counts->add_value()->set_float_value(left_count(best_split_index, i));
495   }
496 
497   auto* right_class_stats = right_stats->mutable_classification();
498   auto* right_counts = right_class_stats->mutable_dense_counts();
499   for (int i = 0; i < params_.num_outputs(); ++i) {
500     right_counts->add_value()->set_float_value(total_counts_[i] -
501                                                left_count(best_split_index, i));
502   }
503 }
504 
505 // ------------------------ Sparse Classification --------------------------- //
// Restores sparse classification stats from a serialized FertileSlot:
// total per-class counts and, per candidate, its split and sparse
// left-side counts.
void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().sparse_counts();

  // Total counts.
  for (auto const& entry : class_stats.sparse_value()) {
    total_counts_[entry.first] = entry.second.float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    // nullptr data/target: only registers the split, no stat updates.
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    for (auto const& entry : left_stats.sparse_value()) {
      const float val = entry.second.float_value();
      left_counts_[split_num][entry.first] = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}
533 
// Serializes this node's stats into a FertileSlot: weight sum, sparse
// total class counts, and per-candidate split + sparse left-side counts.
void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_sparse_counts()
                          ->mutable_sparse_value();
  for (const auto& entry : total_counts_) {
    decision_trees::Value val;
    val.set_float_value(entry.second);
    (*class_stats)[entry.first] = val;
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts()
                           ->mutable_sparse_value();
    for (const auto& entry : left_counts_[split_num]) {
      decision_trees::Value val;
      val.set_float_value(entry.second);
      (*left_stats)[entry.first] = val;
    }
  }
}
562 
GiniScore(int split,float * left_sum,float * right_sum) const563 float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
564                                                float* right_sum) const {
565   float left_square = 0, right_square = 0;
566   *left_sum = 0;
567   *right_sum = 0;
568   for (const auto& entry : total_counts_) {
569     const int label = entry.first;
570     float left = 0;
571     float right = 0;
572     auto it = left_counts_[split].find(label);
573     if (it == left_counts_[split].end()) {
574       right = entry.second;
575     } else {
576       left = it->second;
577       right = entry.second - it->second;
578     }
579     *left_sum += left;
580     left_square += left * left;
581     *right_sum += right;
582     right_square += right * right;
583   }
584   const int32 num_classes = params_.num_outputs();
585   const float left_score =
586       WeightedSmoothedGini(*left_sum, left_square, num_classes);
587   const float right_score =
588       WeightedSmoothedGini(*right_sum, right_square, num_classes);
589   return left_score + right_score;
590 }
591 
// Splits the node's total counts into sparse left/right leaf stats for the
// chosen split.  Classes absent from the left map went entirely right.
void SparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts =
      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts =
      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();

  for (const auto& entry : total_counts_) {
    auto it = left_counts_[best_split_index].find(entry.first);
    if (it == left_counts_[best_split_index].end()) {
      // All of this class's weight went to the right child.
      (*right_counts)[entry.first].set_float_value(entry.second);
    } else {
      const float left = it->second;
      const float right = entry.second - it->second;
      (*left_counts)[entry.first].set_float_value(left);
      // Only emit a right-side entry when some weight is left over.
      if (right > 0) {
        (*right_counts)[entry.first].set_float_value(right);
      }
    }
  }
}
615 
616 // -------------------- FixedSizeClassStats --------------------------------- //
617 
618 // FixedSizeClassStats implements the "SpaceSaving" algorithm by
619 // Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi.  See for example
620 // https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
621 
// Returns the key of the minimum-valued entry in m, or -1 when m is empty.
// Ties are broken by whichever entry the (unordered) iteration reaches
// first.  Fix: iterate by const reference instead of copying each
// key/value pair per iteration (clang-tidy performance-for-range-copy).
int argmin(const std::unordered_map<int, float>& m) {
  int min_class = -1;
  float min_weight = FLT_MAX;
  for (const auto& entry : m) {
    if (entry.second < min_weight) {
      min_weight = entry.second;
      min_class = entry.first;
    }
  }
  return min_class;
}
633 
// Adds weight w to class c using the SpaceSaving scheme: at most n_
// classes are tracked; when the map is full, a previously unseen class
// takes over the smallest-weight entry (inheriting its weight plus w).
void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    // Growing the smallest entry may change which class is now smallest.
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    // Still room for a new class; `it` (== end()) acts as insertion hint.
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // Can't assume last added has the smallest weight, because the
      // w's might be all different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over":  we add our weight to its weight,
  // and assign it all to the new seen class.
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}
664 
get_weight(int c) const665 float FixedSizeClassStats::get_weight(int c) const {
666   // Every entry in class_weights_ might be overstated by as much as the
667   // smallest_weight.  We therefore assume that each has been overstated
668   // by smallest_weight / 2.0, and we re-distribute that mass over all
669   // num_classes_ classes.
670   float smallest_weight = 0.0;
671   auto it = class_weights_.find(smallest_weight_class_);
672   if (it != class_weights_.end()) {
673     smallest_weight = it->second;
674   }
675   float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
676   it = class_weights_.find(c);
677   if (it != class_weights_.end()) {
678     w += it->second - smallest_weight / 2.0;
679   }
680   return w;
681 }
682 
set_sum_and_square(float * sum,float * square) const683 void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
684   *sum = 0.0;
685   *square = 0.0;
686 
687   float smallest_weight = 0.0;
688   auto it = class_weights_.find(smallest_weight_class_);
689   if (it != class_weights_.end()) {
690     smallest_weight = it->second;
691   }
692 
693   float w;
694   for (const auto it : class_weights_) {
695     *sum += it.second;
696     w = get_weight(it.first);
697     *square += w * w;
698   }
699 
700   w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
701   *square += (num_classes_ - n_) * w * w;
702 }
703 
// Restores the tracked class weights from a sparse proto vector; once the
// map is at capacity (n_ entries) the smallest-weight class is recomputed
// so accumulate() can resume the SpaceSaving takeover logic.
void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}
713 
PackToProto(decision_trees::SparseVector * sparse_vector) const714 void FixedSizeClassStats::PackToProto(
715     decision_trees::SparseVector* sparse_vector) const {
716   for (const auto it : class_weights_) {
717     (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
718         it.second);
719   }
720 }
721 
722 // --------------------- FixedSizeSparseClassificationGrowStats ------------- //
723 
// Restores per-candidate fixed-size left/right class stats from a
// serialized FertileSlot.
void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    // nullptr data/target: only registers the split, no stat updates.
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    // Each child gets a fresh FixedSizeClassStats sized by the params.
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}
750 
// Serializes the weight sum plus each candidate's split and fixed-size
// left/right class stats into a FertileSlot.
void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}
769 
GiniScore(int split,float * left_sum,float * right_sum) const770 float FixedSizeSparseClassificationGrowStats::GiniScore(
771     int split, float* left_sum, float* right_sum) const {
772   float left_square, right_square;
773   left_counts_[split].set_sum_and_square(left_sum, &left_square);
774   right_counts_[split].set_sum_and_square(right_sum, &right_square);
775   const int32 num_classes = params_.num_outputs();
776   const float left_score =
777       WeightedSmoothedGini(*left_sum, left_square, num_classes);
778   const float right_score =
779       WeightedSmoothedGini(*right_sum, right_square, num_classes);
780   return left_score + right_score;
781 }
782 
InitLeafClassStats(int best_split_index,LeafStat * left_stats,LeafStat * right_stats) const783 void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
784     int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
785   auto* left_class_stats = left_stats->mutable_classification();
786   auto* left_counts = left_class_stats->mutable_sparse_counts();
787   left_counts_[best_split_index].PackToProto(left_counts);
788 
789   auto* right_class_stats = right_stats->mutable_classification();
790   auto* right_counts = right_class_stats->mutable_sparse_counts();
791   right_counts_[best_split_index].PackToProto(right_counts);
792 }
793 
794 // --------------------- Least Squares Regression --------------------------- //
ExtractFromProto(const FertileSlot & slot)795 void LeastSquaresRegressionGrowStats::ExtractFromProto(
796     const FertileSlot& slot) {
797   const int32 num_outputs = params_.num_outputs();
798   Initialize();
799   if (!slot.has_post_init_leaf_stats()) {
800     return;
801   }
802   weight_sum_ = slot.post_init_leaf_stats().weight_sum();
803   const auto& total_sums =
804       slot.post_init_leaf_stats().regression().mean_output();
805   const auto& total_squares =
806       slot.post_init_leaf_stats().regression().mean_output_squares();
807 
808   // Total counts.
809   for (int i = 0; i < num_outputs; ++i) {
810     total_sum_[i] = total_sums.value(i).float_value();
811     total_sum_squares_[i] = total_squares.value(i).float_value();
812   }
813 
814   // Candidate counts and splits.
815   int split_num = 0;
816   for (const auto& cand : slot.candidates()) {
817     AddSplit(cand.split(), nullptr, nullptr, -1);
818     const auto& sums = cand.left_stats().regression().mean_output();
819     const auto& squares = cand.left_stats().regression().mean_output_squares();
820     for (int i = 0; i < num_outputs; ++i) {
821       left_sum(split_num, i) = sums.value(i).float_value();
822       left_square(split_num, i) = squares.value(i).float_value();
823     }
824     left_counts_[split_num] = cand.left_stats().weight_sum();
825     ++split_num;
826   }
827 }
828 
PackToProto(FertileSlot * slot) const829 void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
830   const int32 num_outputs = params_.num_outputs();
831   auto* slot_stats = slot->mutable_post_init_leaf_stats();
832   slot_stats->set_weight_sum(weight_sum_);
833 
834   auto* total_sums = slot->mutable_post_init_leaf_stats()
835                          ->mutable_regression()
836                          ->mutable_mean_output();
837   auto* total_squares = slot->mutable_post_init_leaf_stats()
838                             ->mutable_regression()
839                             ->mutable_mean_output_squares();
840 
841   for (int i = 0; i < total_sum_.size(); ++i) {
842     total_sums->add_value()->set_float_value(total_sum_[i]);
843     total_squares->add_value()->set_float_value(total_sum_squares_[i]);
844   }
845 
846   for (int split_num = 0; split_num < num_splits(); ++split_num) {
847     auto* cand = slot->add_candidates();
848     *cand->mutable_split() = splits_[split_num];
849     auto* sums =
850         cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
851     auto* squares = cand->mutable_left_stats()
852                         ->mutable_regression()
853                         ->mutable_mean_output_squares();
854     for (int i = 0; i < num_outputs; ++i) {
855       sums->add_value()->set_float_value(left_sum(split_num, i));
856       squares->add_value()->set_float_value(left_square(split_num, i));
857     }
858     cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
859   }
860 }
861 
AddExample(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,int example)862 void LeastSquaresRegressionGrowStats::AddExample(
863     const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
864     int example) {
865   const int32 num_outputs = params_.num_outputs();
866   // Update splits.
867   for (int i = 0; i < num_splits(); ++i) {
868     auto& eval = evaluators_[i];
869     if (eval->Decide(input_data, example) == LEFT_INDEX) {
870       for (int j = 0; j < num_outputs; ++j) {
871         const float output = target->GetTargetAsContinuous(example, j);
872         left_sum(i, j) += output;
873         left_square(i, j) += output * output;
874       }
875       ++left_counts_[i];
876     }
877   }
878 
879   // Update totals.
880   for (int i = 0; i < num_outputs; ++i) {
881     const float output = target->GetTargetAsContinuous(example, i);
882     total_sum_[i] += output;
883     total_sum_squares_[i] += output * output;
884   }
885   weight_sum_ += 1.0;
886 }
887 
SplitVariance(int split) const888 float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
889   float total_variance = 0;
890   for (int i = 0; i < params_.num_outputs(); ++i) {
891     // Left side
892     const float le_x = left_sum(split, i) / left_counts_[split];
893 
894     const float le_x2 = left_square(split, i) / left_counts_[split];
895     total_variance += le_x2 - le_x * le_x;
896 
897     // Right side
898     const float re_x = (total_sum_[i] - left_sum(split, i)) /
899                        (weight_sum_ - left_counts_[split]);
900 
901     const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
902                         (weight_sum_ - left_counts_[split]);
903     total_variance += re_x2 - re_x * re_x;
904   }
905   return total_variance;
906 }
907 
BestSplit(SplitCandidate * best) const908 bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
909   float min_score = FLT_MAX;
910   int best_index = -1;
911   const int32 num_outputs = params_.num_outputs();
912   for (int i = 0; i < num_splits(); ++i) {
913     if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
914       const float split_score = SplitVariance(i);
915       if (split_score < min_score) {
916         min_score = split_score;
917         best_index = i;
918       }
919     }
920   }
921 
922   // This could happen if all the splits are useless.
923   if (best_index < 0) {
924     return false;
925   }
926 
927   // Fill in right stats to be used for leaf model.
928   *best->mutable_split() = splits_[best_index];
929   // Left
930   auto* left = best->mutable_left_stats();
931   auto* left_reg_stats = left->mutable_regression();
932   left->set_weight_sum(left_counts_[best_index]);
933   auto* left_output_sum = left_reg_stats->mutable_mean_output();
934   for (int i = 0; i < num_outputs; ++i) {
935     left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
936   }
937 
938   // Right
939   auto* right = best->mutable_right_stats();
940   auto* right_reg_stats = right->mutable_regression();
941   right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
942   auto* right_output_sum = right_reg_stats->mutable_mean_output();
943   for (int i = 0; i < num_outputs; ++i) {
944     right_output_sum->add_value()->set_float_value(total_sum_[i] -
945                                                    left_sum(best_index, i));
946   }
947   return true;
948 }
949 
IsFinished() const950 bool LeastSquaresRegressionGrowStats::IsFinished() const {
951   return weight_sum_ >= split_after_samples_;
952 }
953 
954 }  // namespace tensorforest
955 }  // namespace tensorflow
956