// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"

#include <cfloat>
#include <queue>

#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace tensorforest {

// When creating evaluators for the split candidates, use these
// for the left and right return values.
static const int32 LEFT_INDEX = 0;
static const int32 RIGHT_INDEX = 1;
GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
    : weight_sum_(0),
      depth_(depth),
      params_(params),
      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
      num_splits_to_consider_(
          ResolveParam(params.num_splits_to_consider(), depth)),
      num_outputs_(params.num_outputs()) {}

void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
                         const std::unique_ptr<TensorDataSet>& input_data,
                         const InputTarget* target, int example) {
  // It's possible that the split collection calls AddSplit, but we actually
  // have all the splits we need and are just waiting for them to be fully
  // initialized.
  if (splits_.size() < num_splits_to_consider_) {
    splits_.push_back(split);
    evaluators_.emplace_back(
        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
    AddSplitStats(target, example);
  }

  if (input_data != nullptr && target != nullptr &&
      params_.initialize_average_splits()) {
    AdditionalInitializationExample(input_data, target, example);
  }
}

void GrowStats::RemoveSplit(int split_num) {
  splits_.erase(splits_.begin() + split_num);
  evaluators_.erase(evaluators_.begin() + split_num);
  RemoveSplitStats(split_num);
}

// ------------------------ Classification --------------------------- //

ClassificationStats::ClassificationStats(const TensorForestParams& params,
                                         int32 depth)
    : GrowStats(params, depth), finish_early_(false) {
  // Early splitting params.
  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
    min_split_samples_ = split_after_samples_;
    finish_sample_epoch_ = 1;
    finish_check_every_ = split_after_samples_ * 2;
  } else {
    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
      LOG(FATAL) << "dominate_fraction and min_split_samples "
                 << "required for early-finish strategy.";
    } else {
      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
      finish_check_every_ =
          ResolveParam(params.finish_type().check_every_steps(), depth);
      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;

      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
      }
    }
  }

  // Pruning params.
  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
    prune_check_every_ =
        ResolveParam(params.pruning_type().prune_every_samples(), depth);
    prune_sample_epoch_ = 1;
    prune_fraction_ = 0.0;
    switch (params_.pruning_type().type()) {
      case SPLIT_PRUNE_HALF:
        prune_fraction_ = 0.5;
        break;
      case SPLIT_PRUNE_QUARTER:
        prune_fraction_ = 0.25;
        break;
      case SPLIT_PRUNE_10_PERCENT:
        prune_fraction_ = 0.10;
        break;
      case SPLIT_PRUNE_HOEFFDING:
        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
        break;
      default:
        LOG(WARNING) << "Unknown pruning type";
    }
  } else {
    prune_check_every_ = split_after_samples_ * 2;
    prune_sample_epoch_ = 1;
  }

  if (params.use_running_stats_method()) {
    left_gini_.reset(new RunningGiniScores());
    right_gini_.reset(new RunningGiniScores());
  }

  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
      new random::PhiloxRandom(random::New64()));
  rng_ = std::unique_ptr<random::SimplePhilox>(
      new random::SimplePhilox(single_rand_.get()));
}

void ClassificationStats::AdditionalInitializationExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
  std::unordered_set<int> to_erase;
  for (auto it = half_initialized_splits_.begin();
       it != half_initialized_splits_.end(); ++it) {
    if (it->second != new_target) {
      auto& split = splits_[it->first];
      if (split.has_inequality_left_child_test()) {
        auto& test = split.inequality_left_child_test();
        auto* thresh =
            split.mutable_inequality_left_child_test()->mutable_threshold();
        if (test.has_feature_id()) {
          const float val =
              input_data->GetExampleValue(example, test.feature_id());
          thresh->set_float_value((thresh->float_value() + val) / 2);
        }
      }
      to_erase.insert(it->first);
    }
  }

  for (const int split_id : to_erase) {
    half_initialized_splits_.erase(split_id);
  }
}

bool ClassificationStats::IsFinished() const {
  bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
  return basic || finish_early_;
}

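// When use_running_stats_method is set, left_gini_ and right_gini_ hold
// incrementally maintained sums and sums of squares, so scoring a split is
// O(1) here instead of a pass over the class counts.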
float ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
                                                float* right_sum) const {
  if (left_gini_ == nullptr) {
    return GiniScore(split, left_sum, right_sum);
  } else {
    *left_sum = left_gini_->sum(split);
    const float left = WeightedSmoothedGini(
        *left_sum, left_gini_->square(split), num_outputs_);

    *right_sum = right_gini_->sum(split);
    const float right = WeightedSmoothedGini(
        *right_sum, right_gini_->square(split), num_outputs_);

    return left + right;
  }
}

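// Routes the example through every candidate split's evaluator, updating the
// per-split class counts (and the running Gini caches when enabled), then
// updates the node totals and runs the early-finish and pruning checks.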
void ClassificationStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
  const float weight = target->GetTargetWeight(example);

  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      if (left_gini_ != nullptr) {
        left_gini_->update(i, left_count(i, int_label), weight);
      }
      ClassificationAddLeftExample(i, int_label, weight);
    } else {
      if (right_gini_ != nullptr) {
        right_gini_->update(i, right_count(i, int_label), weight);
      }
      ClassificationAddRightExample(i, int_label, weight);
    }
  }

  ClassificationAddTotalExample(int_label, weight);

  weight_sum_ += weight;

  CheckFinishEarly();
  CheckPrune();
}

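// Pruning runs at most once per sample epoch: the body fires only after
// another prune_check_every_ units of example weight have accumulated. The
// fraction-based strategies drop the worst (highest-Gini) prune_fraction_ of
// the candidates; e.g. with SPLIT_PRUNE_HALF and ten candidates, each check
// removes the five highest-scoring splits.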
void ClassificationStats::CheckPrune() {
  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
    return;
  }
  ++prune_sample_epoch_;

  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
    CheckPruneHoeffding();
    return;
  }

  const int to_remove = num_splits() * prune_fraction_;
  if (to_remove <= 0) {
    return;
  }

  // Pair ordering is first-then-second by default, no need for a custom
  // comparison. Use std::greater to make it a min-heap.
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>>
      worst;

  // Track indices that are in the heap so we can iterate over them
  // by largest-first later.
  std::set<int> indices;

  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    const float split_score = MaybeCachedGiniScore(i, &left, &right);
    if (worst.size() < to_remove) {
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    } else if (worst.top().first < split_score) {
      indices.erase(worst.top().second);
      worst.pop();
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    }
  }

  // Traverse indices from the back so that removals don't shift the
  // positions of splits that still need to be removed.
  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
    RemoveSplit(*it);
  }
}

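// Prunes every candidate whose Gini score trails the current best by more
// than the Hoeffding bound
//   epsilon = R * sqrt(ln(1 / (1 - dominate_fraction_)) / (2 * N)),
// where R is the range of the weighted Gini difference and N is weight_sum_.
// half_ln_dominate_frac_ pre-computes 0.5 * ln(1 / (1 - f)), so
// sqrt(half_ln_dominate_frac_ / N) equals sqrt(ln(1 / (1 - f)) / (2 * N)).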
void ClassificationStats::CheckPruneHoeffding() {
  std::vector<float> split_scores(num_splits());
  // Find the best (lowest) split score.
  float best_split_score = FLT_MAX;
  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
    if (split_scores[i] < best_split_score) {
      best_split_score = split_scores[i];
    }
  }

  // We apply the Hoeffding bound to the difference between the best split
  // score and the i-th split score.
  // Raw Gini ranges from 0 to 1 - (1/n), but our Gini score is weighted.
  const float num_classes = params_.num_outputs();
  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
  // Iterate from the back so that RemoveSplit doesn't shift the indices of
  // splits we haven't examined yet.
  for (int i = num_splits() - 1; i >= 0; i--) {
    if (split_scores[i] - best_split_score > epsilon) {
      RemoveSplit(i);
    }
  }
}

void ClassificationStats::CheckFinishEarly() {
  if (weight_sum_ < min_split_samples_ ||
      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
    return;
  }
  ++finish_sample_epoch_;

  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
    CheckFinishEarlyHoeffding();
  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
    CheckFinishEarlyBootstrap();
  }
}

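// Finishes the node early once the best candidate's Gini score beats the
// runner-up's by more than the Hoeffding bound, i.e. once the best split is
// unlikely (with confidence roughly dominate_fraction_) to be overtaken.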
void ClassificationStats::CheckFinishEarlyHoeffding() {
  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;

  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));

  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
}

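// Fills `weights` with Laplace-smoothed class probabilities for the given
// split: entries [0, num_outputs_) come from the left branch's counts,
// entries [num_outputs_, 2 * num_outputs_) from the right branch's.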
void ClassificationStats::MakeBootstrapWeights(int index,
                                               std::vector<float>* weights) {
  int n = weight_sum_;
  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
  for (int i = 0; i < num_outputs_; ++i) {
    // Use the Laplace smoothed per-class probabilities when generating the
    // bootstrap samples.
    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
  }
}

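// Returns the smallest k for which (1 - dominate_fraction_) * 2^(k-1) >= 1,
// roughly 1 + log2(1 / (1 - dominate_fraction_)). For example,
// dominate_fraction_ = 0.99 leaves p = 0.01, which must double seven times
// to reach 1.0, so eight bootstrap samples are drawn.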
int ClassificationStats::NumBootstrapSamples() const {
  float p = 1.0 - dominate_fraction_;
  int bootstrap_samples = 1;
  while (p < 1.0) {
    ++bootstrap_samples;
    p = p * 2;
  }
  return bootstrap_samples;
}

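// Draws bootstrap Gini samples for the two best candidates and finishes
// early only if even the worst (largest) resampled Gini of the best split
// beats the best (smallest) resampled Gini of the runner-up.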
void ClassificationStats::CheckFinishEarlyBootstrap() {
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  std::vector<float> weights1(num_outputs_ * 2);
  MakeBootstrapWeights(best_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2(num_outputs_ * 2);
  MakeBootstrapWeights(second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int bootstrap_samples = NumBootstrapSamples();

  int worst_g1 = 0;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
    worst_g1 = std::max(worst_g1, g1);
  }

  int best_g2 = 99;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
    best_g2 = std::min(best_g2, g2);
  }

  finish_early_ = worst_g1 < best_g2;
}

bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  float best_left_sum, best_right_sum;

  // Find the lowest Gini among the candidates. Splits that send all the
  // weight to one side are useless, so require both sums to be positive.
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    if (left_sum > 0 && right_sum > 0 && split_score < min_score) {
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  InitLeafClassStats(best_index, left, right);

  return true;
}

// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  const int32 num_classes = params_.num_outputs();
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().dense_counts();

  // Total counts.
  for (int i = 0; i < num_classes; ++i) {
    total_counts_[i] = class_stats.value(i).float_value();
    num_outputs_seen_ += total_counts_[i] != 0;
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().dense_counts();
    for (int i = 0; i < num_classes; ++i) {
      const float val = left_stats.value(i).float_value();
      mutable_left_count(split_num, i) = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_dense_counts();
  for (int i = 0; i < num_outputs_; ++i) {
    class_stats->add_value()->set_float_value(total_counts_[i]);
  }

  // Only the left-branch counts are serialized per candidate; the right
  // branch can be reconstructed as total minus left.
  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_dense_counts();
    for (int i = 0; i < num_outputs_; ++i) {
      left_stats->add_value()->set_float_value(left_count(split_num, i));
    }
  }
}

float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                              float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (int j = 0; j < num_outputs_; ++j) {
    const float left = left_count(split, j);
    *left_sum += left;
    left_square += left * left;
    const float right = right_count(split, j);
    *right_sum += right;
    right_square += right * right;
  }

  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
  return left_score + right_score;
}

void DenseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    left_counts->add_value()->set_float_value(left_count(best_split_index, i));
  }

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    right_counts->add_value()->set_float_value(total_counts_[i] -
                                               left_count(best_split_index, i));
  }
}

// ------------------------ Sparse Classification --------------------------- //
void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().sparse_counts();

  // Total counts.
  for (auto const& entry : class_stats.sparse_value()) {
    total_counts_[entry.first] = entry.second.float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    for (auto const& entry : left_stats.sparse_value()) {
      const float val = entry.second.float_value();
      left_counts_[split_num][entry.first] = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_sparse_counts()
                          ->mutable_sparse_value();
  for (const auto& entry : total_counts_) {
    decision_trees::Value val;
    val.set_float_value(entry.second);
    (*class_stats)[entry.first] = val;
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts()
                           ->mutable_sparse_value();
    for (const auto& entry : left_counts_[split_num]) {
      decision_trees::Value val;
      val.set_float_value(entry.second);
      (*left_stats)[entry.first] = val;
    }
  }
}

float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                               float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (const auto& entry : total_counts_) {
    const int label = entry.first;
    float left = 0;
    float right = 0;
    auto it = left_counts_[split].find(label);
    if (it == left_counts_[split].end()) {
      right = entry.second;
    } else {
      left = it->second;
      right = entry.second - it->second;
    }
    *left_sum += left;
    left_square += left * left;
    *right_sum += right;
    right_square += right * right;
  }
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void SparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts =
      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts =
      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();

  for (const auto& entry : total_counts_) {
    auto it = left_counts_[best_split_index].find(entry.first);
    if (it == left_counts_[best_split_index].end()) {
      (*right_counts)[entry.first].set_float_value(entry.second);
    } else {
      const float left = it->second;
      const float right = entry.second - it->second;
      (*left_counts)[entry.first].set_float_value(left);
      if (right > 0) {
        (*right_counts)[entry.first].set_float_value(right);
      }
    }
  }
}

// -------------------- FixedSizeClassStats --------------------------------- //

// FixedSizeClassStats implements the "SpaceSaving" algorithm by
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf

// Returns the key of the map entry with the smallest value, or -1 if the
// map is empty. Linear in the size of the map.
int argmin(const std::unordered_map<int, float>& m) {
  int c = -1;
  float f = FLT_MAX;
  for (const auto& it : m) {
    if (it.second < f) {
      f = it.second;
      c = it.first;
    }
  }
  return c;
}

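// A minimal usage sketch (hypothetical values), assuming the
// (capacity, num_classes) constructor declared in grow_stats.h: track at
// most two classes while three distinct classes are observed.
//
//   FixedSizeClassStats stats(/*n=*/2, /*num_classes=*/3);
//   stats.accumulate(/*c=*/0, /*w=*/5.0);  // {0: 5.0}
//   stats.accumulate(1, 3.0);              // {0: 5.0, 1: 3.0}; smallest is 1
//   stats.accumulate(2, 1.0);              // 2 takes over 1's entry:
//                                          // {0: 5.0, 2: 4.0}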
void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // Can't assume the last class added has the smallest weight, because
      // the w's might all be different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over": we add our weight to its weight,
  // and assign it all to the newly seen class.
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}

float FixedSizeClassStats::get_weight(int c) const {
  // Every entry in class_weights_ might be overstated by as much as the
  // smallest_weight. We therefore assume that each has been overstated
  // by smallest_weight / 2.0, and we re-distribute that mass over all
  // num_classes_ classes.
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }
  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    w += it->second - smallest_weight / 2.0;
  }
  return w;
}

void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
  *sum = 0.0;
  *square = 0.0;

  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  float w;
  for (const auto& it : class_weights_) {
    *sum += it.second;
    w = get_weight(it.first);
    *square += w * w;
  }

  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  *square += (num_classes_ - n_) * w * w;
}

void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}

void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  for (const auto& it : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
        it.second);
  }
}

// --------------------- FixedSizeSparseClassificationGrowStats ------------- //

void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}

void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}

float FixedSizeSparseClassificationGrowStats::GiniScore(
    int split, float* left_sum, float* right_sum) const {
  float left_square, right_square;
  left_counts_[split].set_sum_and_square(left_sum, &left_square);
  right_counts_[split].set_sum_and_square(right_sum, &right_square);
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_sparse_counts();
  left_counts_[best_split_index].PackToProto(left_counts);

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_sparse_counts();
  right_counts_[best_split_index].PackToProto(right_counts);
}

// --------------------- Least Squares Regression --------------------------- //
void LeastSquaresRegressionGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  const int32 num_outputs = params_.num_outputs();
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& total_sums =
      slot.post_init_leaf_stats().regression().mean_output();
  const auto& total_squares =
      slot.post_init_leaf_stats().regression().mean_output_squares();

  // Total sums and sums of squares.
  for (int i = 0; i < num_outputs; ++i) {
    total_sum_[i] = total_sums.value(i).float_value();
    total_sum_squares_[i] = total_squares.value(i).float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& sums = cand.left_stats().regression().mean_output();
    const auto& squares = cand.left_stats().regression().mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      left_sum(split_num, i) = sums.value(i).float_value();
      left_square(split_num, i) = squares.value(i).float_value();
    }
    left_counts_[split_num] = cand.left_stats().weight_sum();
    ++split_num;
  }
}

void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
  const int32 num_outputs = params_.num_outputs();
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* total_sums = slot->mutable_post_init_leaf_stats()
                         ->mutable_regression()
                         ->mutable_mean_output();
  auto* total_squares = slot->mutable_post_init_leaf_stats()
                            ->mutable_regression()
                            ->mutable_mean_output_squares();

  for (int i = 0; i < total_sum_.size(); ++i) {
    total_sums->add_value()->set_float_value(total_sum_[i]);
    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* sums =
        cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
    auto* squares = cand->mutable_left_stats()
                        ->mutable_regression()
                        ->mutable_mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      sums->add_value()->set_float_value(left_sum(split_num, i));
      squares->add_value()->set_float_value(left_square(split_num, i));
    }
    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
  }
}

void LeastSquaresRegressionGrowStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 num_outputs = params_.num_outputs();
  // Update splits. Only left-branch statistics are stored; the right branch
  // is later derived from the totals.
  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      for (int j = 0; j < num_outputs; ++j) {
        const float output = target->GetTargetAsContinuous(example, j);
        left_sum(i, j) += output;
        left_square(i, j) += output * output;
      }
      ++left_counts_[i];
    }
  }

  // Update totals.
  for (int i = 0; i < num_outputs; ++i) {
    const float output = target->GetTargetAsContinuous(example, i);
    total_sum_[i] += output;
    total_sum_squares_[i] += output * output;
  }
  weight_sum_ += 1.0;
}

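// Returns the sum over all outputs of the left and right branches' variances,
// each computed as E[x^2] - E[x]^2; the right branch's moments are derived by
// subtracting the left branch's sums from the node totals.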
float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
  float total_variance = 0;
  for (int i = 0; i < params_.num_outputs(); ++i) {
    // Left side.
    const float le_x = left_sum(split, i) / left_counts_[split];
    const float le_x2 = left_square(split, i) / left_counts_[split];
    total_variance += le_x2 - le_x * le_x;

    // Right side.
    const float re_x = (total_sum_[i] - left_sum(split, i)) /
                       (weight_sum_ - left_counts_[split]);
    const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
                        (weight_sum_ - left_counts_[split]);
    total_variance += re_x2 - re_x * re_x;
  }
  return total_variance;
}

bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  const int32 num_outputs = params_.num_outputs();
  for (int i = 0; i < num_splits(); ++i) {
    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
      const float split_score = SplitVariance(i);
      if (split_score < min_score) {
        min_score = split_score;
        best_index = i;
      }
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in the stats to be used for the leaf model.
  *best->mutable_split() = splits_[best_index];
  // Left.
  auto* left = best->mutable_left_stats();
  auto* left_reg_stats = left->mutable_regression();
  left->set_weight_sum(left_counts_[best_index]);
  auto* left_output_sum = left_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
  }

  // Right.
  auto* right = best->mutable_right_stats();
  auto* right_reg_stats = right->mutable_regression();
  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
  auto* right_output_sum = right_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    right_output_sum->add_value()->set_float_value(total_sum_[i] -
                                                   left_sum(best_index, i));
  }
  return true;
}

bool LeastSquaresRegressionGrowStats::IsFinished() const {
  return weight_sum_ >= split_after_samples_;
}

}  // namespace tensorforest
}  // namespace tensorflow