syntax = "proto3";
option cc_enable_arenas = true;

package tensorflow.tensorforest;

import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";

// Tracks the per-node training statistics ("fertile slots") used while
// growing a tensor forest, plus the inference-time tree-path trace.

message FertileStats {
  // Tracks stats for each node. node_to_slot[i] is the FertileSlot for node i.
  // This may be sized to max_nodes initially, or grow dynamically as needed.
  repeated FertileSlot node_to_slot = 1;
}

message GiniStats {
  // NOTE(review): field number 1 is unused here; reserve it so a future
  // field cannot accidentally reuse it. TODO: confirm against schema
  // history whether 1 ever carried data on the wire.
  reserved 1;

  // This allows us to quickly track and calculate impurity (classification)
  // by storing the sum of input weights and the sum of the squares of the
  // input weights. Weighted gini is then: 1 - (square / sum * sum).
  // Updates to these numbers are:
  //   old_i = leaf->value(label)
  //   new_i = old_i + incoming_weight
  //   sum -> sum + incoming_weight
  //   square -> square - (old_i ^ 2) + (new_i ^ 2)
  //   total_left_sum -> total_left_sum - old_left_i * old_total_i +
  //                     new_left_i * new_total_i
  // Sum of the squares of the input weights seen so far (the "square" term
  // in the weighted-gini formula above).
  float square = 2;
}

message LeafStat {
  // The sum of the weights of the training examples that we have seen.
  // This is here, outside of the leaf_stat oneof, because almost all
  // types will want it.
  float weight_sum = 3;

  // TODO(thomaswc): Move the GiniStats out of LeafStats and into something
  // that only tracks them for splits.
  message GiniImpurityClassificationStats {
    // Per-class weighted example counts; dense or sparse depending on the
    // number of classes.
    oneof counts {
      decision_trees.Vector dense_counts = 1;
      decision_trees.SparseVector sparse_counts = 2;
    }
    // Running gini-impurity bookkeeping for these counts.
    GiniStats gini = 3;
  }

  // This is the info needed for calculating variance for regression.
  // Variance will still have to be summed over every output, but the
  // number of outputs in regression problems is almost always 1.
  message LeastSquaresRegressionStats {
    // Running mean of each output dimension.
    decision_trees.Vector mean_output = 1;
    // Running mean of the square of each output dimension.
    decision_trees.Vector mean_output_squares = 2;
  }

  // Exactly one of these is populated, depending on whether the forest is
  // doing classification or regression.
  oneof leaf_stat {
    GiniImpurityClassificationStats classification = 1;
    LeastSquaresRegressionStats regression = 2;
    // TODO(thomaswc): Add in v5's SparseClassStats.
  }
}

message FertileSlot {
  // NOTE(review): field numbers 2 and 3 are unused; reserve them so they
  // cannot be accidentally reused with a new meaning. TODO: confirm
  // against schema history whether they ever carried data on the wire.
  reserved 2, 3;

  // The statistics for *all* the examples seen at this leaf.
  LeafStat leaf_stats = 4;

  // Candidate splits currently being evaluated for this leaf.
  repeated SplitCandidate candidates = 1;

  // The statistics for the examples seen at this leaf after all the
  // splits have been initialized. If post_init_leaf_stats.weight_sum
  // is > 0, then all candidates have been initialized. We need to track
  // both leaf_stats and post_init_leaf_stats because the first is used
  // to create the decision_tree::Leaf and the second is used to infer
  // the statistics for the right side of a split (given the leaf side
  // stats).
  LeafStat post_init_leaf_stats = 6;

  // Id of the tree node this slot belongs to.
  int32 node_id = 5;
  // Depth of that node in the tree.
  int32 depth = 7;
}

message SplitCandidate {
  // NOTE(review): field numbers 2 and 3 are unused; reserve them so they
  // cannot be accidentally reused with a new meaning. TODO: confirm
  // against schema history whether they ever carried data on the wire.
  reserved 2, 3;

  // proto representing the potential node.
  decision_trees.BinaryNode split = 1;

  // Right counts are inferred from FertileSlot.leaf_stats and left.
  LeafStat left_stats = 4;

  // Right stats (not full counts) are kept here.
  LeafStat right_stats = 5;

  // Fields used when training with a graph runner.
  string unique_id = 6;
}

// Proto used for tracking tree paths during inference time.
message TreePath {
  // Nodes are listed in order that they were traversed. i.e. nodes_visited[0]
  // is the tree's root node.
  repeated decision_trees.TreeNode nodes_visited = 1;
}