// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"

namespace tensorflow {
namespace tensorforest {

// Base class for tracking stats necessary to split a leaf.
// Holds and tracks stats for every candidate split.
class GrowStats {
 public:
  virtual ~GrowStats() {}
  // Perform any initialization.
  virtual void Initialize() = 0;

  // Add an example to any stats being collected.
  virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                          const InputTarget* target, int example) = 0;

  // Fill in the best split, return false if none were valid.
  virtual bool BestSplit(SplitCandidate* best) const = 0;

  // Return true if this leaf is finished splitting.
  virtual bool IsFinished() const = 0;

  // Get the split_num BinaryNode.
  const decision_trees::BinaryNode& Split(int split_num) const {
    return splits_[split_num];
  }

  // Clear all state.
  virtual void Clear() {
    weight_sum_ = 0;
    splits_.clear();
    evaluators_.clear();
    ClearInternal();
  }

  virtual void ExtractFromProto(const FertileSlot& slot) = 0;
  virtual void PackToProto(FertileSlot* slot) const = 0;

  // Add split to the list of candidate splits.
  void AddSplit(const decision_trees::BinaryNode& split,
                const std::unique_ptr<TensorDataSet>& input_data,
                const InputTarget* target, int example);
  virtual void AdditionalInitializationExample(
      const std::unique_ptr<TensorDataSet>& input_data,
      const InputTarget* target, int example) {}
  void RemoveSplit(int split_num);

  int num_splits() const { return splits_.size(); }

  float weight_sum() const { return weight_sum_; }

  virtual bool IsInitialized() const {
    return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
  }

  int32 depth() const { return depth_; }

 protected:
  GrowStats(const TensorForestParams& params, int32 depth);

  // Function called by AddSplit for subclasses to initialize stats for a
  // split.
  virtual void AddSplitStats(const InputTarget* target, int example) = 0;

  virtual void RemoveSplitStats(int split_num) = 0;

  // Function called by Clear for subclasses to clear their state.
  virtual void ClearInternal() = 0;

  std::vector<decision_trees::BinaryNode> splits_;
  std::vector<std::unique_ptr<DecisionNodeEvaluator>> evaluators_;

  float weight_sum_;

  const int32 depth_;

  const TensorForestParams& params_;

  // We cache these because they're used often.
  const int split_after_samples_;
  const int num_splits_to_consider_;

  const int32 num_outputs_;
};
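
// Typical lifecycle of a GrowStats subclass, as a hedged sketch: the real
// driver lives elsewhere in tensor_forest, and `data`, `target`, `example_id`
// and `node` below are illustrative placeholders rather than names from this
// header. A concrete subclass such as DenseClassificationGrowStats (declared
// below) would be exercised roughly like this:
//
//   DenseClassificationGrowStats stats(params, /*depth=*/0);
//   stats.Initialize();
//   // Register candidate splits, then stream examples through the stats.
//   stats.AddSplit(node, data, target, example_id);
//   stats.AddExample(data, target, example_id);
//   if (stats.IsFinished()) {
//     SplitCandidate best;
//     if (stats.BestSplit(&best)) {
//       // Use `best` to turn the leaf into an internal node.
//     }
//   }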

// Don't track anything, useful for systems that want to track split
// candidates but train the model in some other way.
class SimpleStats : public GrowStats {
 public:
  SimpleStats(const TensorForestParams& params, int32 depth)
      : GrowStats(params, depth) {}
  void Initialize() override {}

  void ExtractFromProto(const FertileSlot& slot) override {}
  void PackToProto(FertileSlot* slot) const override {}

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override {
    weight_sum_ += target->GetTargetWeight(example);
  }

  bool BestSplit(SplitCandidate* best) const override { return false; }

  bool IsFinished() const override {
    return weight_sum_ >= split_after_samples_;
  }

 protected:
  void AddSplitStats(const InputTarget* target, int example) override {}
  void RemoveSplitStats(int split_num) override {}
  void ClearInternal() override {}
};

// Tracks the sum and square of one side of a split for each Gini calculation.
class RunningGiniScores {
 public:
  float sum(int split) const { return sum_[split]; }
  float square(int split) const { return square_[split]; }

  void update(int split, float old_val, float weight) {
    sum_[split] += weight;
    const float new_val = old_val + weight;
    square_[split] = square_[split] - old_val * old_val + new_val * new_val;
  }

  void add_split() {
    sum_.push_back(0);
    square_.push_back(0);
  }

  void remove_split(int i) {
    sum_.erase(sum_.begin() + i);
    square_.erase(square_.begin() + i);
  }

 private:
  std::vector<float> sum_;
  std::vector<float> square_;
};
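
// A note on the bookkeeping above (a sketch of the math, not a statement
// about the exact scoring code): for per-class weights c_i on one side of a
// split, sum(split) holds sum_i c_i and square(split) holds sum_i c_i^2, so
// the Gini impurity of that side,
//   1 - sum_i (c_i / sum)^2 = 1 - square / (sum * sum),
// can be read off in O(1). For example, with class weights {3, 1}:
// sum = 4, square = 3^2 + 1^2 = 10, impurity = 1 - 10 / 16 = 0.375.
// update() keeps square consistent when a single class weight moves from
// old_val to old_val + weight.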

class ClassificationStats : public GrowStats {
 public:
  ClassificationStats(const TensorForestParams& params, int32 depth);

  bool IsFinished() const override;

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override;

  void AdditionalInitializationExample(
      const std::unique_ptr<TensorDataSet>& input_data,
      const InputTarget* target, int example) override;

  bool IsInitialized() const override {
    return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ &&
                               half_initialized_splits_.empty());
  }

  bool BestSplit(SplitCandidate* best) const override;
  // When best_split_index has been chosen as the best split,
  // InitLeafClassStats is used to initialize the LeafStats of the two
  // new leaves.
  virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                                  LeafStat* right_stats) const = 0;

 protected:
  virtual float GiniScore(int split, float* left_sum,
                          float* right_sum) const = 0;

  // is_pure should return true if at most one class label has been seen
  // at the node, and false if two or more have been seen.
  virtual bool is_pure() const = 0;
  virtual float left_count(int split, int class_num) const = 0;
  virtual float right_count(int split, int class_num) const = 0;

  virtual void ClassificationAddLeftExample(int split, int64 int_label,
                                            float weight) = 0;
  virtual void ClassificationAddRightExample(int split, int64 int_label,
                                             float weight) {
    // Does nothing by default, but subclasses can override.
  }
  virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;

  virtual void ClassificationAddSplitStats() = 0;
  virtual void ClassificationRemoveSplitStats(int split) = 0;

  void AddSplitStats(const InputTarget* target, int example) override {
    if (left_gini_ != nullptr) {
      left_gini_->add_split();
      right_gini_->add_split();
    }
    if (params_.initialize_average_splits()) {
      if (splits_[splits_.size() - 1].has_inequality_left_child_test()) {
        half_initialized_splits_[splits_.size() - 1] =
            target->GetTargetAsClassIndex(example, 0);
      }
    }
    ClassificationAddSplitStats();
  }
  void RemoveSplitStats(int split) override {
    if (left_gini_ != nullptr) {
      left_gini_->remove_split(split);
      right_gini_->remove_split(split);
    }
    ClassificationRemoveSplitStats(split);
  }

  // Virtual so we can override these to test.
  virtual void CheckFinishEarly();
  virtual void CheckFinishEarlyHoeffding();
  virtual void CheckFinishEarlyBootstrap();

  virtual void CheckPrune();

  // Implements SplitPruningStrategyType::SPLIT_PRUNE_HOEFFDING.
  void CheckPruneHoeffding();

  // Returns the Gini score, calculated from the sums and squares cached in
  // left_gini_ and right_gini_ when available, otherwise from raw counts.
  float MaybeCachedGiniScore(int split, float* left_sum,
                             float* right_sum) const;

  // Initialize the sum and squares of left_gini_ and right_gini_ for the
  // given split and value (extracted from a proto), if left_gini_ isn't null.
  void MaybeInitializeRunningCount(int split, float val) {
    if (left_gini_ != nullptr) {
      left_gini_->update(split, 0, val);
      right_gini_->update(split, 0, val);
    }
  }

  int NumBootstrapSamples() const;

  // Populate *weights with the smoothed per-class frequencies needed to
  // initialize a DistributionSampler.
  void MakeBootstrapWeights(int index, std::vector<float>* weights);

  // Accessors for RunningGiniScores objects, for testing.
  virtual const std::unique_ptr<RunningGiniScores>& get_left_gini() const {
    return left_gini_;
  }
  virtual const std::unique_ptr<RunningGiniScores>& get_right_gini() const {
    return right_gini_;
  }

 private:
  // Tracks how many check_every_samples epochs we've seen go by in weight_sum.
  int32 finish_sample_epoch_;
  int32 finish_check_every_;
  int32 prune_sample_epoch_;
  int32 prune_check_every_;
  bool finish_early_;
  int32 min_split_samples_;
  float dominate_fraction_;
  float prune_fraction_;

  // When using SPLIT_PRUNE_HOEFFDING, we precompute and store
  // 0.5 * ln(1 / (1.0 - dominate_fraction_)).
  float half_ln_dominate_frac_;
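
  // A hedged note on the cached constant above: with the standard Hoeffding
  // bound, the deviation of an empirical mean of n samples in [0, R] from its
  // expectation is at most
  //   epsilon = sqrt(R^2 * ln(1 / delta) / (2 * n))
  // with probability at least 1 - delta. Taking R = 1 and
  // delta = 1 - dominate_fraction_, this reduces to
  //   epsilon = sqrt(half_ln_dominate_frac_ / n),
  // which is presumably why the half-log factor is precomputed here; the
  // exact use lives in the pruning code in the corresponding .cc file, not in
  // this header.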

  std::unique_ptr<random::PhiloxRandom> single_rand_;
  std::unique_ptr<random::SimplePhilox> rng_;

  std::unique_ptr<RunningGiniScores> left_gini_;
  std::unique_ptr<RunningGiniScores> right_gini_;

  // Stores split number -> class that was first seen.
  std::unordered_map<int, int32> half_initialized_splits_;
};
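
// ClassificationStats implements the bookkeeping shared by all classification
// growers (finish-early checks, pruning, and cached Gini sums) and delegates
// the storage of per-class counts to its subclasses through the
// Classification* hooks above. The three subclasses below differ mainly in
// that storage: a dense vector, a sparse unordered_map, or a fixed-size
// approximation.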

// Tracks classification stats by storing class counts densely.
class DenseClassificationGrowStats : public ClassificationStats {
 public:
  DenseClassificationGrowStats(const TensorForestParams& params, int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override {
    Clear();
    total_counts_.resize(num_outputs_);
  }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    left_counts_.resize(num_outputs_ * num_splits());
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + num_outputs_ * split_num,
                       left_counts_.begin() + num_outputs_ * (split_num + 1));
  }
  void ClearInternal() override {
    total_counts_.clear();
    left_counts_.clear();
    num_outputs_seen_ = 0;
  }

  bool is_pure() const override { return num_outputs_seen_ <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    mutable_left_count(split, int_label) += weight;
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    num_outputs_seen_ += total_counts_[int_label] == 0 && weight > 0;
    total_counts_[int_label] += weight;
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split * num_outputs_ + class_num];
  }
  float right_count(int split, int class_num) const override {
    return total_counts_[class_num] -
           left_counts_[split * num_outputs_ + class_num];
  }

 private:
  inline float& mutable_left_count(int split, int class_num) {
    return left_counts_[split * num_outputs_ + class_num];
  }
  // Total class counts seen at this leaf
  std::vector<float> total_counts_;

  // Also track the number of classes seen for not splitting pure leaves.
  int num_outputs_seen_;

  // Left-branch taken class counts at this leaf for each split.
  // This is a flat vector for memory-performance reasons.
  // left_counts_[i * num_outputs_ + j] has the j-th class count for split i.
  std::vector<float> left_counts_;
};

// Tracks classification stats by storing class counts sparsely.
class SparseClassificationGrowStats : public ClassificationStats {
 public:
  SparseClassificationGrowStats(const TensorForestParams& params, int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override { Clear(); }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    left_counts_.resize(num_splits());
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
  }
  void ClearInternal() override {
    total_counts_.clear();
    left_counts_.clear();
  }

  bool is_pure() const override { return total_counts_.size() <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    left_counts_[split][int_label] += weight;
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    total_counts_[int_label] += weight;
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split].at(class_num);
  }
  float right_count(int split, int class_num) const override {
    return total_counts_.at(class_num) - left_counts_[split].at(class_num);
  }

 private:
  // Total class counts seen at this leaf
  std::unordered_map<int, float> total_counts_;

  // Left-branch taken class counts at this leaf for each split.
  // left_counts_[i][j] has the j-th class count for split i.
  std::vector<std::unordered_map<int, float>> left_counts_;
};

// Accumulates weights for the most popular classes while only using a
// fixed amount of space.
class FixedSizeClassStats {
 public:
  // n specifies how many classes are tracked.
  FixedSizeClassStats(int n, int num_classes)
      : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}

  // Add weight w to the class c.
  void accumulate(int c, float w);

  // Return the approximate accumulated weight for class c.  If c isn't one
  // of the n-most popular classes, this can be 0 even if c has accumulated
  // some weight.
  float get_weight(int c) const;

  // Put the sum of all weights seen into *sum, and
  // \sum_c get_weight(c)^2
  // into *square.  *sum will be exact, but *square will be approximate.
  void set_sum_and_square(float* sum, float* square) const;

  void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
  void PackToProto(decision_trees::SparseVector* sparse_vector) const;

 private:
  // For our typical use cases, n_ is between 10 and 100, so there's no
  // need to track the smallest weight with a min_heap or the like.
  int n_;
  int num_classes_;

  // This tracks the class of the smallest weight, but isn't set until
  // class_weights_.size() == n_.
  int smallest_weight_class_;

  std::unordered_map<int, float> class_weights_;
};
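
// Example usage of FixedSizeClassStats (an illustrative sketch based only on
// the contract documented above; the numeric behavior once more than n
// distinct classes have been seen depends on the implementation in the
// corresponding .cc file):
//
//   FixedSizeClassStats stats(/*n=*/2, /*num_classes=*/10);
//   stats.accumulate(/*c=*/3, /*w=*/5.0);
//   stats.accumulate(/*c=*/7, /*w=*/2.0);
//   // With at most n distinct classes seen, weights are exact:
//   //   stats.get_weight(3) == 5.0 and stats.get_weight(7) == 2.0.
//   // Once more than n classes have been seen, get_weight() for the less
//   // popular classes may return 0, and the *square value filled in by
//   // set_sum_and_square() becomes approximate, while *sum stays exact.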

// Tracks classification stats sparsely in a fixed amount of space.
class FixedSizeSparseClassificationGrowStats : public ClassificationStats {
 public:
  FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
                                         int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override { Clear(); }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    FixedSizeClassStats stats(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_.resize(num_splits(), stats);
    right_counts_.resize(num_splits(), stats);
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
    right_counts_.erase(right_counts_.begin() + split_num,
                        right_counts_.begin() + (split_num + 1));
  }
  void ClearInternal() override {
    left_counts_.clear();
    right_counts_.clear();
  }

  bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    left_counts_[split].accumulate(int_label, weight);
  }
  void ClassificationAddRightExample(int split, int64 int_label,
                                     float weight) override {
    right_counts_[split].accumulate(int_label, weight);
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    if (is_pure()) {
      first_two_classes_seen_.insert(int_label);
    }
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split].get_weight(class_num);
  }

  float right_count(int split, int class_num) const override {
    return right_counts_[split].get_weight(class_num);
  }

 private:
  std::vector<FixedSizeClassStats> left_counts_;
  std::vector<FixedSizeClassStats> right_counts_;

  // We keep track of the first two class labels seen, so we can tell if
  // the node is pure (= all of one class) or not.
  std::set<int> first_two_classes_seen_;
};

// Tracks regression stats using least-squares minimization.
class LeastSquaresRegressionGrowStats : public GrowStats {
 public:
  LeastSquaresRegressionGrowStats(const TensorForestParams& params,
                                  int32 depth)
      : GrowStats(params, depth) {}

  void Initialize() override {
    Clear();
    total_sum_.resize(num_outputs_);
    total_sum_squares_.resize(num_outputs_);
  }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override;
  bool BestSplit(SplitCandidate* best) const override;
  bool IsFinished() const override;

 protected:
  // Returns the variance of split.
  float SplitVariance(int split) const;
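
  // A hedged note on the split score: for least-squares regression the
  // natural per-output statistic is the variance
  //   Var = E[y^2] - (E[y])^2,
  // which the flat sums below make cheap to compute. As an illustrative
  // sketch (not necessarily the exact formula in the corresponding .cc
  // file), the left-branch variance of output j for split i would be roughly
  //   left_square(i, j) / left_counts_[i]
  //       - (left_sum(i, j) / left_counts_[i])^2,
  // with the right-branch terms obtained by subtracting the left sums and
  // squares from total_sum_ and total_sum_squares_.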

  void AddSplitStats(const InputTarget* target, int example) override {
    left_sums_.resize(num_outputs_ * num_splits());
    left_squares_.resize(num_outputs_ * num_splits());
    left_counts_.push_back(0);
  }
  void RemoveSplitStats(int split_num) override {
    left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num,
                     left_sums_.begin() + num_outputs_ * (split_num + 1));
    left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num,
                        left_squares_.begin() + num_outputs_ * (split_num + 1));
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
  }

  void ClearInternal() override {
    total_sum_.clear();
    total_sum_squares_.clear();
    left_sums_.clear();
    left_squares_.clear();
  }

 private:
  // Convenience methods for accessing the flat count vectors.
  inline const float& left_sum(int split, int output_num) const {
    return left_sums_[split * num_outputs_ + output_num];
  }
  inline float& left_sum(int split, int output_num) {
    return left_sums_[split * num_outputs_ + output_num];
  }
  inline const float& left_square(int split, int output_num) const {
    return left_squares_[split * num_outputs_ + output_num];
  }
  inline float& left_square(int split, int output_num) {
    return left_squares_[split * num_outputs_ + output_num];
  }

  // Total sums and squares seen at this leaf.
  // sum[i] is the sum of the i-th output.
  std::vector<float> total_sum_;
  std::vector<float> total_sum_squares_;

  // Per-split sums and squares, stored flat for performance.
  // left_sums_[i * num_outputs_ + j] has the j-th sum for split i.
  std::vector<float> left_sums_;
  std::vector<float> left_squares_;

  // The number of examples seen at each split.
  std::vector<int64> left_counts_;
};

}  // namespace tensorforest
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_