syntax = "proto3";
option cc_enable_arenas = true;

package tensorflow.tensorforest;

import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";

// Top-level container of per-node training statistics for a tree.
message FertileStats {
  // Tracks stats for each node.  node_to_slot[i] is the FertileSlot for node i.
  // This may be sized to max_nodes initially, or grow dynamically as needed.
  repeated FertileSlot node_to_slot = 1;
}

// Sufficient statistics for incrementally maintaining weighted Gini
// impurity during classification training.
message GiniStats {
  // NOTE(review): field number 1 is skipped and the comments below still
  // reference a "sum" quantity, which suggests a deleted `float sum = 1;`
  // field.  Reserve the number (and the presumed name) so they cannot be
  // reused with different semantics — reusing a deleted field number is a
  // wire-compatibility break.
  reserved 1;
  reserved "sum";

  // This allows us to quickly track and calculate impurity (classification)
  // by storing the sum of input weights and the sum of the squares of the
  // input weights.  Weighted gini is then: 1 - (square / (sum * sum)).
  // Updates to these numbers are:
  //   old_i = leaf->value(label)
  //   new_i = old_i + incoming_weight
  //   sum -> sum + incoming_weight
  //   square -> square - (old_i ^ 2) + (new_i ^ 2)
  //   total_left_sum -> total_left_sum - old_left_i * old_total_i +
  //                                      new_left_i * new_total_i
  // Sum of squared input weights; see the update rules above.
  float square = 2;
}

// Per-leaf accumulated statistics, specialized by problem type via the
// `leaf_stat` oneof below.
message LeafStat {
  // The sum of the weights of the training examples that we have seen.
  // This is here, outside of the leaf_stat oneof, because almost all
  // types will want it.
  float weight_sum = 3;

  // TODO(thomaswc): Move the GiniStats out of LeafStats and into something
  // that only tracks them for splits.
  // Classification statistics: per-class example counts plus the Gini
  // impurity accumulators.
  message GiniImpurityClassificationStats {
    // Per-class weighted counts; sparse vs. dense is a storage choice.
    oneof counts {
      decision_trees.Vector dense_counts = 1;
      decision_trees.SparseVector sparse_counts = 2;
    }
    // Running Gini impurity accumulators for these counts.
    GiniStats gini = 3;
  }

  // This is the info needed for calculating variance for regression.
  // Variance will still have to be summed over every output, but the
  // number of outputs in regression problems is almost always 1.
  message LeastSquaresRegressionStats {
    // Running mean of each output dimension.
    decision_trees.Vector mean_output = 1;
    // Running mean of the squares of each output dimension.
    decision_trees.Vector mean_output_squares = 2;
  }

  // Exactly one of these is set, depending on the problem type.
  oneof leaf_stat {
    GiniImpurityClassificationStats classification = 1;
    LeastSquaresRegressionStats regression = 2;
    // TODO(thomaswc): Add in v5's SparseClassStats.
  }
}

// Training state for a single fertile (growable) leaf: its accumulated
// statistics and the candidate splits being evaluated for it.
message FertileSlot {
  // NOTE(review): field numbers 2 and 3 are skipped (1, 4, 5, 6, 7 are in
  // use), which suggests deleted fields.  Reserve them so the numbers are
  // never reused with different semantics — that would be a wire break.
  reserved 2, 3;

  // The statistics for *all* the examples seen at this leaf.
  LeafStat leaf_stats = 4;

  // Candidate splits currently being evaluated for this leaf.
  repeated SplitCandidate candidates = 1;

  // The statistics for the examples seen at this leaf after all the
  // splits have been initialized.  If post_init_leaf_stats.weight_sum
  // is > 0, then all candidates have been initialized.  We need to track
  // both leaf_stats and post_init_leaf_stats because the first is used
  // to create the decision_tree::Leaf and the second is used to infer
  // the statistics for the right side of a split (given the leaf side
  // stats).
  LeafStat post_init_leaf_stats = 6;

  // Id of the tree node this slot belongs to.
  int32 node_id = 5;
  // Depth of that node in the tree.
  int32 depth = 7;
}

// A potential split of a fertile leaf, along with the statistics gathered
// for the examples routed to each side.
message SplitCandidate {
  // NOTE(review): field numbers 2 and 3 are skipped (1, 4, 5, 6 are in
  // use), which suggests deleted fields.  Reserve them so the numbers are
  // never reused with different semantics — that would be a wire break.
  reserved 2, 3;

  // proto representing the potential node.
  decision_trees.BinaryNode split = 1;

  // Right counts are inferred from FertileSlot.leaf_stats and left.
  LeafStat left_stats = 4;

  // Right stats (not full counts) are kept here.
  LeafStat right_stats = 5;

  // Fields used when training with a graph runner.
  string unique_id = 6;
}

// Proto used for tracking tree paths during inference time.
message TreePath {
  // Nodes are listed in order that they were traversed. i.e. nodes_visited[0]
  // is the tree's root node and the last entry is the leaf reached.
  repeated decision_trees.TreeNode nodes_visited = 1;
}
