1syntax = "proto3";
2
3package tensorflow.tpu;
4
5import "google/protobuf/wrappers.proto";
6import "tensorflow/compiler/xla/service/hlo.proto";
7
// Bounds used to clip values (OptimizationParameters uses this for both weight
// and gradient clipping). Each bound is optional: an unset bound is treated as
// infinite, so that side is not clipped.
message ClippingLimits {
  google.protobuf.FloatValue lower = 1;  // -inf if not set
  google.protobuf.FloatValue upper = 2;  // +inf if not set
}
12
// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
// actual learning rates are provided as a scalar input list to the
// SendTPUEmbeddingGradients Op, indexed by the tag specified through the
// following proto.
message DynamicLearningRate {
  // For tables where learning rates are dynamically computed and communicated
  // to the TPU embedding program, a tag must be specified for the learning
  // rate.
  //
  // The tag must be a non-negative integer. The total number of unique tags
  // must be less than or equal to the number of tables in the TPU embedding
  // configuration (a table does not specify any tag if it uses a constant
  // learning rate, and specifies exactly one tag if it uses dynamic learning
  // rates).
  //
  // All tags in the range [0, number_of_unique_tags) must be present in the
  // TPU embedding configuration, i.e. a tag cannot be skipped if a different
  // tag numerically greater than it is used in the configuration.
  //
  // If multiple tables specify the same tag, they *MUST* have the same dynamic
  // learning rate; for example, their dynamic learning rate could be computed
  // by the same TensorFlow sub-graph. The embedding layer can be partitioned
  // more efficiently when number_of_unique_tags is as *LOW* as possible, i.e.,
  // when many tables share the same tag.
  //
  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
  // communicate dynamic learning rates to the TPU embedding program.
  // The learning_rate input is a list of scalars where the size of the list is
  // equal to the number of unique tags. The learning rate associated with a
  // particular tag is specified by populating its corresponding index in the
  // list of learning_rate scalars.
  //
  // Example: if the configuration uses tags {0, 1}, the learning_rate list has
  // two entries, and learning_rate[1] is the rate for every table whose tag
  // is 1.
  int32 tag = 1;
}
46
// Source of learning rate to use.
message LearningRate {
  // At most one source can be set (oneof semantics).
  // NOTE(review): behavior when neither source is set is not specified here;
  // confirm against the consumer of this proto.
  oneof learning_rate {
    // Constant learning rate used for every update.
    float constant = 1;
    // Learning rate supplied at runtime through the learning_rate input of the
    // SendTPUEmbeddingGradients op, selected by DynamicLearningRate.tag.
    DynamicLearningRate dynamic = 2;
  }
}
54
55// Each optimizer's parameter proto has a link to its documentation and CPU
56// implementation (if available) for user reference.
57
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634
//
// Adagrad has no algorithm-specific hyperparameters in this message; the
// learning rate is configured via OptimizationParameters.learning_rate.
message AdagradParameters {
  // Retired initial accumulator setting; never reuse its number or name.
  reserved 1;
  reserved "initial_accumulator";
}
65
// Adagrad with optional bounds on the variable update and on the accumulator.
// Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
message BoundedAdagradParameters {
  // Whether to use the updated or the old value of the accumulator when
  // computing the effective learning rate. When update_accumulator_first is set
  // to true, the updated value of the accumulator is used.
  bool update_accumulator_first = 1;
  // The max_var_update value to use. Set value to 0 (default) to disable using
  // max_var_update to clip the gradient.
  float max_var_update = 2;
  // The maximum value of the accumulator. Set max_accumulator to 0 (default)
  // to disable using max_accumulator to clip the accumulator.
  float max_accumulator = 3;
}
79
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
//
// Plain SGD has no algorithm-specific hyperparameters; the shared settings in
// OptimizationParameters (e.g. learning_rate) apply.
message StochasticGradientDescentParameters {}
83
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646
//
// The hyperparameters for FTRL are the same as for the Keras implementation,
// with some additions. The "beta" parameter matches the behavior described in
// the second link above; "beta" / (2 * learning rate) should be added to "l2"
// to get equivalent behavior in the other TensorFlow implementations of this
// optimizer. When the multiply_linear_by_lr field is set to true, a modified
// formula is used for FTRL that treats the "linear" accumulator as being
// pre-multiplied by the learning rate (i.e., the accumulator named "linear"
// actually stores "linear * learning_rate"). Other than checkpoint
// compatibility, this is mathematically equivalent for a static learning rate;
// for a dynamic learning rate, it is nearly the same as long as the learning
// rate does not change quickly. The benefit of setting multiply_linear_by_lr to
// true is that the modified formula handles zero and near-zero learning rates
// without producing NaNs, improving flexibility for learning rate ramp-up. The
// allow_zero_accumulator parameter changes some internal formulas to allow zero
// and near-zero accumulator values at the cost of some performance; this only
// needs to be set if you are using an initial accumulator value of zero, which
// is uncommon.
message FtrlParameters {
  // L1 regularization strength (Keras Ftrl: l1_regularization_strength).
  float l1 = 1;
  // L2 regularization strength (Keras Ftrl: l2_regularization_strength).
  float l2 = 2;
  // Learning rate power (Keras Ftrl: learning_rate_power).
  float lr_power = 3;
  // FTRL "beta" term; see the message comment for its relation to l2.
  float beta = 7;
  // If true, the "linear" accumulator stores linear * learning_rate; see the
  // message comment.
  bool multiply_linear_by_lr = 6;
  // If true, use formulas that tolerate zero/near-zero accumulator values, at
  // some performance cost.
  bool allow_zero_accumulator = 8;

  // Old initial accumulator parameters.
  reserved "initial_accum", "initial_linear";
  reserved 4, 5;
}
117
// The Adam optimizer does not implement hyper-parameter update due to hardware
// limitations; use the dynamic learning rate feature instead, setting the
// learning rate to: user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
//
// NOTE(review): tf.contrib was removed in TensorFlow 2.x, so the two
// tf/contrib links above may no longer resolve.
message AdamParameters {
  // Exponential decay rate for the 1st moment estimates (m).
  float beta1 = 3;
  // Exponential decay rate for the 2nd moment estimates (v).
  float beta2 = 4;
  // Small constant for numerical stability; where it enters the update
  // depends on use_sum_inside_sqrt (see the message comment).
  float epsilon = 5;
  // If true, implement non-lazy Adam (see the message comment); gradient
  // accumulation must then be enabled for correct results.
  bool use_non_lazy_adam = 8;
  // If true, use m / sqrt(v + epsilon**2) instead of m / (sqrt(v) + epsilon).
  bool use_sum_inside_sqrt = 10;

  // Old initial accumulator parameters.
  // NOTE(review): field numbers 1, 2 and 9 are neither used nor reserved here;
  // if they were ever used, they should be reserved too.
  reserved "initial_m", "initial_v";
  reserved 6, 7;
}
149
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068
//
// SGD with momentum, optionally Nesterov-style.
message MomentumParameters {
  // Momentum coefficient (Keras SGD `momentum`).
  float momentum = 1;
  // Whether to use Nesterov momentum (Keras SGD `nesterov`).
  bool use_nesterov = 2;

  // Retired initial accumulator setting; never reuse its number or name.
  reserved 3;
  reserved "initial_accum";
}
160
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229
//
// Standard (uncentered) RMSprop.
message RmsPropParameters {
  // Decay rate (Keras RMSprop `rho`).
  float rho = 1;
  // Momentum coefficient (Keras RMSprop `momentum`).
  float momentum = 2;
  // Small constant for numerical stability (Keras RMSprop `epsilon`).
  float epsilon = 3;

  // Retired initial accumulator settings; never reuse these numbers or names.
  reserved 4, 5;
  reserved "initial_ms", "initial_mom";
}
172
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358
//
// Centered RMSprop (Keras RMSprop with centered=True).
message CenteredRmsPropParameters {
  // Decay rate (Keras RMSprop `rho`).
  float rho = 1;
  // Momentum coefficient (Keras RMSprop `momentum`).
  float momentum = 2;
  // Small constant for numerical stability (Keras RMSprop `epsilon`).
  float epsilon = 3;

  // Retired initial accumulator settings; never reuse these numbers or names.
  reserved 4, 5, 6;
  reserved "initial_ms", "initial_mom", "initial_mg";
}
184
// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
//
// Per the weight_decay_factor documentation in OptimizationParameters, this
// optimizer does not work with weight decay.
// NOTE(review): the individual hyperparameters below are not documented in
// this file; confirm their semantics against the TPU embedding implementation
// before relying on them.
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;

  // Old initial accumulator parameters.
  reserved "initial_accumulator", "initial_weight", "initial_benefit";
  reserved 13, 14, 15;
}
204
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933
//
// Adadelta.
message AdadeltaParameters {
  // Decay rate (Keras Adadelta `rho`).
  float rho = 1;
  // Small constant for numerical stability (Keras Adadelta `epsilon`).
  float epsilon = 2;

  // Retired initial accumulator settings; never reuse these numbers or names.
  reserved 3, 4;
  reserved "initial_accumulator", "initial_update";
}
215
// https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961
//
// Adagrad with proximal L1/L2 regularization.
message ProximalAdagradParameters {
  // L1 regularization strength (TF1 `l1_regularization_strength`).
  float l1 = 1;
  // L2 regularization strength (TF1 `l2_regularization_strength`).
  float l2 = 2;

  // Retired initial accumulator setting; never reuse its number or name.
  reserved 3;
  reserved "initial_accumulator";
}
226
// The online Yogi optimizer does not implement hyper-parameter update; use the
// dynamic learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
// NOTE(review): this message has no beta1 field; confirm whether the
// (1 - beta1^t) factor applies to online Yogi or was copied from Adam.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of online Yogi.
message OnlineYogiParameters {
  // The L1 regularization parameter (used analogously to the one in FTRL).
  float l1 = 1;

  // The L2 regularization parameter (used analogously to the one in FTRL).
  float l2 = 2;

  // \beta_2 from Algorithm 2 in the paper.
  float beta2 = 3;

  // Reserved ids corresponding to the removed sign and tanh activations.
  // NOTE(review): field numbers 4 and 5 are neither used nor reserved here;
  // if they were ever used, they should be reserved too.
  reserved 6;  // sign
  reserved 7;  // tanh
}
250
// The proximal Yogi optimizer does not implement hyper-parameter update; use
// the dynamic learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of proximal Yogi.
message ProximalYogiParameters {
  // The L1 regularization parameter.
  float l1 = 1;

  // The L2 regularization parameter.
  float l2 = 2;

  // The exponential decay rate for the 1st moment estimates.
  float beta1 = 3;

  // The exponential decay rate for the 2nd moment estimates.
  float beta2 = 4;

  // A constant trading off adaptivity and noise.
  float epsilon = 5;

  // Reserved ids corresponding to the removed sign and tanh activations.
  // NOTE(review): field numbers 6 and 7 are neither used nor reserved here;
  // if they were ever used, they should be reserved too.
  reserved 8;  // sign
  reserved 9;  // tanh
}
280
// Estimator for the frequency of updates to a lookup table. It maintains an
// array (tf.Variable) D, where each element records the average number of
// global steps between two consecutive batches that hit the corresponding
// bucket. Once an item with bucket id i is sampled, D[i] is updated by:
//   D[i] <- D[i] * (1 - tau) + delta[i] * tau,
//
// where tau is a learning rate between 0 and 1 (exclusive), and
//   delta[i] = current global step - last step i is sampled.
//
// The estimated frequency (sampling rate in a batch) is thus 1 / D[i].
//
// Elements in D are initialized with a large value max_delta. delta[i] will
// also be capped by this value.
//
// The exact sequence of operations used in the optimizer is shown below.
// last_hit_step[i] is a tf.Variable that holds the last global step at which i
// was sampled.
//
//   delta = global_step - last_hit_step[i]
//   clipped_delta = min(delta, params.max_delta)
//   is_outlier = (delta >= params.outlier_threshold * D[i])
//   D[i] <- is_outlier ? clipped_delta
//                      : D[i] * (1 - params.tau) + clipped_delta * params.tau
//   last_hit_step[i] <- global_step
message FrequencyEstimatorParameters {
  // Learning rate between (0, 1) that is used to update the array D.
  float tau = 1;

  // Maximum value of delta: difference between the current global step and the
  // last global step at which the row was sampled. Also the initial value of
  // every element of D.
  float max_delta = 2;

  // Threshold used to determine whether the current update is an outlier: an
  // update is an outlier when delta >= outlier_threshold * D[i], in which case
  // D[i] is overwritten with clipped_delta instead of being averaged.
  float outlier_threshold = 3;

  // The weight exponent used to transform the estimated delta into weights.
  // The transformation function is: (delta / max_delta) ^ (weight_exponent)
  float weight_exponent = 4;
}
320
// A user-defined optimizer.
// The contained HLO program must take the following arguments in the following
// order:
// 1.  gradients
// 2.  table weights
// 3.  slot variables
// 4.  an optional scalar input that is passed in via the dynamic learning
//     rate mechanism.
//
// It must return/end in a tuple op that contains the following values in the
// following order:
// 1.  new table values
// 2.  new slot variable values
//
// The program must have shape (1,1) with dtype float32 throughout and only use
// HLO ops that operate elementwise (e.g., no reduce, no variables, no control
// flow and no broadcasting outside of the single scalar input).
// The HLO program should be written as if it were a dense update. It will be
// called on each row that needs an update and will be applied elementwise.
message UserDefinedProgramParameters {
  // The optimizer update as an XLA HLO module satisfying the contract above.
  xla.HloModuleProto program = 1;
  // Padding values for the parameter and the slots; see
  // StateVariableSpecification.UserDefined.padding_initial_value below for
  // more details on how these should be set. One value is needed for the
  // weights and one for each slot.
  repeated float padding_values = 2;
}
348
// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
// the enum for scoping: enum value names are scoped to the enclosing message,
// so the unprefixed UNSPECIFIED/ENABLED/DISABLED names cannot collide with
// other enums in package tensorflow.tpu.
message GradientAccumulationStatus {
  // If UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
}
361
// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled (by default).
message HotIdReplicationConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If UNSPECIFIED (default), hot ID optimization is DISABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
  // Requested hot ID optimization status; defaults to UNSPECIFIED (disabled).
  Status status = 1;
}
374
// Full optimizer configuration: learning rate, weight/gradient clipping,
// weight decay, gradient accumulation, hot ID replication, and the choice of
// optimization algorithm together with its hyperparameters.
message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. Almost all optimizers are supported with this option (MDL Adagrad
  // Light does not work, and SGD does not behave as expected if it is enabled).
  // Although there is no check, users who want weight decay will probably also
  // want to enable gradient accumulation so that the decay happens once per
  // minibatch.
  float weight_decay_factor = 16;

  // If true, the weight decay factor is multiplied by the current learning rate
  // before use; this is to match the note in DecoupledWeightDecayExtension in
  // weight_decay_optimizers.py.
  bool multiply_weight_decay_factor_by_learning_rate = 22;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm). UNSPECIFIED (the
  // default) means gradient accumulation is ENABLED.
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Configuration proto for hot ID replication. This is an experimental
  // feature that is currently disabled (by default).
  HotIdReplicationConfiguration hot_id_replication_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines which
  // algorithm to use.
  oneof parameters {
    AdagradParameters adagrad = 3;
    BoundedAdagradParameters bounded_adagrad = 19;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
    OnlineYogiParameters online_yogi = 20;
    ProximalYogiParameters proximal_yogi = 21;
    FrequencyEstimatorParameters frequency_estimator = 23;
    UserDefinedProgramParameters user_defined_program = 24;
  }

  reserved 15;  // Old use_gradient_accumulation.
}
432
// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops.
  message UserDefined {
    // For padding embedding rows, this field specifies the initial value to be
    // used. Separate initial values need to be specified for the embeddings and
    // any extra accumulators. The initial values should be specified so as to
    // maintain two invariants during model training:
    // (1) The embedding vector multiplied by zero returns a vector containing
    //     all zeros. To maintain this invariant, the embedding values should
    //     never be NaNs or +-infinity.
    // (2) Repeatedly applying the optimizer using a gradient vector of all
    //     zeros does not cause the embeddings or slot variables to become NaNs
    //     or +-infinity.
    // The padding row is looked up when no embedding IDs are present for a
    // feature. The semantics of embedding lookup dictate that the output must
    // be zero under this scenario.
    double padding_initial_value = 1;
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    // Constant value the state variable is filled with.
    double initial_value = 1;
  }

  // Usage type of this state variable; at most one kind can be set (oneof).
  oneof usage {
    UserDefined user_defined = 2;
    FillWithConstant fill_with_constant = 3;
  }
}
472