// Configuration messages for the boosted trees learner.
syntax = "proto3";

// Enable arena allocation support in the generated C++ code.
option cc_enable_arenas = true;

package tensorflow.boosted_trees.learner;

// Tree regularization config.
message TreeRegularizationConfig {
  // Classic L1/L2 regularization terms.
  float l1 = 1;
  float l2 = 2;

  // Tree complexity penalizes overall model complexity, effectively limiting
  // how deep the tree can grow in regions with small gain.
  float tree_complexity = 3;
}

// Tree constraints config: structural limits applied to the trees.
message TreeConstraintsConfig {
  // Maximum depth of the trees. The default value is 6 if not specified.
  uint32 max_tree_depth = 1;

  // Minimum hessian weight per node.
  float min_node_weight = 2;

  // Maximum number of unique features used in the tree. Zero means there is
  // no limit.
  int64 max_number_of_unique_feature_columns = 3;
}

// LearningRateConfig describes all supported learning rate tuners.
message LearningRateConfig {
  // The selected tuner. As a oneof, at most one member is set at a time;
  // setting one clears the others.
  oneof tuner {
    // Fixed learning rate.
    LearningRateFixedConfig fixed = 1;
    // Dropout-driven learning rate.
    LearningRateDropoutDrivenConfig dropout = 2;
    // Learning rate tuned by line search.
    LearningRateLineSearchConfig line_search = 3;
  }
}

// Config for a fixed learning rate.
message LearningRateFixedConfig {
  // The learning rate value to use.
  float learning_rate = 1;
}

// Config for a tuned learning rate.
message LearningRateLineSearchConfig {
  // Max learning rate. Must be strictly positive.
  float max_learning_rate = 1;

  // Number of learning rate values to consider in the range
  // [0, max_learning_rate).
  int32 num_steps = 2;
}

// When we have a sequence of trees 1, 2, 3 ... n, these essentially represent
// weight updates in functional space, and thus we can use averaging of weight
// updates to achieve better performance. For example, we can say that our
// final ensemble will be an average of the ensemble of tree 1, the ensemble of
// tree 1 and tree 2, and so on, up to the ensemble of all trees.
// Note that this averaging will apply ONLY DURING PREDICTION. The training
// stays the same.
message AveragingConfig {
  // At most one of the following averaging windows is set at a time.
  oneof config {
    // NOTE(review): presumably the number of most recent trees to average
    // over; it is declared as a float although it reads like a count.
    // Confirm the intended semantics before relying on this.
    float average_last_n_trees = 1;
    // Between 0 and 1. If set to 1.0, we are averaging the ensemble of tree 1,
    // the ensemble of tree 1 and tree 2, and so on, up to the ensemble of all
    // trees. If set to 0.5, the last half of the trees are averaged, etc.
    float average_last_percent_trees = 2;
  }
}

// Config for a dropout-driven learning rate.
message LearningRateDropoutDrivenConfig {
  // Probability of dropping each tree in the ensemble built so far.
  float dropout_probability = 1;

  // When trees are built after dropout happens, they don't "advance" to the
  // optimal solution, they just rearrange the path. However, you can still
  // choose to skip dropout periodically, to allow a new tree that "advances"
  // to be added.
  // For example, if running for 200 steps with probability of dropout 1/100,
  // you would expect the dropout to start happening for sure for all
  // iterations after 100. However, you can add
  // probability_of_skipping_dropout of 0.1; this way iterations 100-200 will
  // include approx 90 iterations of dropout and 10 iterations of normal
  // steps. Set it to 0 if you want to just keep building the refinement trees
  // after dropout kicks in.
  float probability_of_skipping_dropout = 2;

  // Between 0 and 1.
  float learning_rate = 3;
}

// LearnerConfig is the top-level learner configuration. It combines the
// regularization, constraint, learning rate and averaging configs with the
// pruning, growing, multi-class and weak learner options declared below.
//
// NOTE(review): field number 7 is not used by this message. If it belonged to
// a removed field, consider adding `reserved 7;` so it is never reused.
message LearnerConfig {
  // Tree pruning mode. See `pruning_mode` for the default.
  enum PruningMode {
    PRUNING_MODE_UNSPECIFIED = 0;
    PRE_PRUNE = 1;
    POST_PRUNE = 2;
  }

  // Tree growing mode. See `growing_mode` for the default.
  enum GrowingMode {
    GROWING_MODE_UNSPECIFIED = 0;
    WHOLE_TREE = 1;
    LAYER_BY_LAYER = 2;
  }

  // Multi-class handling strategy. See `multi_class_strategy` for the
  // defaults.
  enum MultiClassStrategy {
    MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
    TREE_PER_CLASS = 1;
    FULL_HESSIAN = 2;
    DIAGONAL_HESSIAN = 3;
  }

  // Type of weak learner. The zero value is NORMAL_DECISION_TREE, so an unset
  // `weak_learner_type` field reads as NORMAL_DECISION_TREE.
  enum WeakLearnerType {
    NORMAL_DECISION_TREE = 0;
    OBLIVIOUS_DECISION_TREE = 1;
  }

  // Number of classes.
  uint32 num_classes = 1;

  // Fraction of features to consider in each tree sampled randomly
  // from all available features. At most one of the members is set.
  oneof feature_fraction {
    float feature_fraction_per_tree = 2;
    float feature_fraction_per_level = 3;
  }

  // Regularization.
  TreeRegularizationConfig regularization = 4;

  // Constraints.
  TreeConstraintsConfig constraints = 5;

  // Pruning. POST_PRUNE is the default pruning mode.
  PruningMode pruning_mode = 8;

  // Growing Mode. LAYER_BY_LAYER is the default growing mode.
  GrowingMode growing_mode = 9;

  // Learning rate. By default we use fixed learning rate of 0.1.
  LearningRateConfig learning_rate_tuner = 6;

  // Multi-class strategy. By default we use TREE_PER_CLASS for binary
  // classification and linear regression. For other cases, we use
  // DIAGONAL_HESSIAN as the default.
  MultiClassStrategy multi_class_strategy = 10;

  // If you want to average the ensembles (for regularization), provide the
  // config below.
  AveragingConfig averaging_config = 11;

  // By default we use NORMAL_DECISION_TREE as weak learner.
  WeakLearnerType weak_learner_type = 12;
}