1syntax = "proto3"; 2 3package tensorflow.tpu; 4 5import "google/protobuf/wrappers.proto"; 6 7message ClippingLimits { 8 google.protobuf.FloatValue lower = 1; // -inf if not set 9 google.protobuf.FloatValue upper = 2; // +inf if not set 10} 11 12// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The 13// actual learning rates are provided as a scalar input list to the 14// SendTPUEmbeddingGradients Op indexed by their tag specified through the 15// following proto. 16message DynamicLearningRate { 17 // For tables where learning rates are dynamically computed and communicated 18 // to the TPU embedding program, a tag must be specified for the learning 19 // rate. 20 // 21 // The tag must be a non-negative integer. The total number of unique tags 22 // must be less than or equal to the number of tables in the TPU embedding 23 // configuration (a table does not specify any tag if it uses a constant 24 // learning rate, and specifies exactly one tag if it uses dynamic learning 25 // rates). 26 // 27 // All tags in the range [0, number_of_unique_tags) must be present in the TPU 28 // embedding configuration, i.e. a tag cannot be skipped if a different tag 29 // numerically greater than it is used in the configuration. 30 // 31 // If multiple tables specify the same tag, they *MUST* have 32 // the same dynamic learning rate, for example, their dynamic learning rate 33 // could be computed by the same TensorFlow sub-graph. The partitioning of the 34 // embedding layer would be more optimal if the number_of_unique_tags is as 35 // *LOW* as possible, i.e., if many tables share the same tag. 36 // 37 // The learning_rate input of the SendTPUEmbeddingGradients op is used to 38 // communicate dynamic learning rates to the TPU embedding program. 39 // The learning_rate input is a list of scalars where the size of the list is 40 // equal to the number of unique tags. The learning rate associated with a 41 // particular tag is specified by populating its corresponding index in the 42 // list of learning_rate scalars. 43 int32 tag = 1; 44} 45 46// Source of learning rate to use. 47message LearningRate { 48 oneof learning_rate { 49 float constant = 1; 50 DynamicLearningRate dynamic = 2; 51 } 52} 53 54// Each optimizer's parameter proto has a link to its documentation and CPU 55// implementation (if available) for user reference. 56 57// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer 58// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151 59message AdagradParameters { 60 float initial_accumulator = 1; 61} 62 63// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer 64// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423 65message StochasticGradientDescentParameters { 66} 67 68// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer 69// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192 70message FtrlParameters { 71 float l1 = 1; 72 float l2 = 2; 73 float lr_power = 3; 74 float initial_accum = 4; 75 float initial_linear = 5; 76} 77 78// The Adam optimizer does not implement hyper-parameter update; use the dynamic 79// learning rate feature instead, setting the learning rate to: 80// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) 81// Here, t is the current timestep. 
// Each optimizer's parameter proto has a link to its documentation and CPU
// implementation (if available) for user reference.

// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
  float initial_accumulator = 1;
}

// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {}

// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
  float l1 = 1;
  float l2 = 2;
  float lr_power = 3;
  float initial_accum = 4;
  float initial_linear = 5;
}

// The Adam optimizer does not implement hyper-parameter update; use the
// dynamic learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
  float beta1 = 3;
  float beta2 = 4;
  float epsilon = 5;
  float initial_m = 6;
  float initial_v = 7;
  bool use_non_lazy_adam = 8;
  bool use_sum_inside_sqrt = 10;
}

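// Illustrative example (values are made up, not defaults): an AdamParameters
// configuration, selected via the adam field of the OptimizationParameters
// message defined below, paired with a dynamic learning rate. As noted above,
// the scalar fed to SendTPUEmbeddingGradients for the chosen tag should
// already include the factor sqrt(1 - beta2^t) / (1 - beta1^t), since the TPU
// embedding program does not update this hyper-parameter itself.
//
//   learning_rate { dynamic { tag: 0 } }
//   adam {
//     beta1: 0.9
//     beta2: 0.999
//     epsilon: 1e-8
//     use_sum_inside_sqrt: true
//   }
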
// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
  float momentum = 1;
  bool use_nesterov = 2;
  float initial_accum = 3;
}

// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;
  float initial_ms = 4;
  float initial_mom = 5;
}

// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;
  float initial_ms = 4;
  float initial_mom = 5;
  float initial_mg = 6;
}

// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;
  float initial_accumulator = 13;
  float initial_weight = 14;
  float initial_benefit = 15;
}

// https://www.tensorflow.org/api_docs/python/tf/train/AdadeltaOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
  float rho = 1;
  float epsilon = 2;
  float initial_accumulator = 3;
  float initial_update = 4;
}

// https://www.tensorflow.org/api_docs/python/tf/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
  float l1 = 1;
  float l2 = 2;
  float initial_accumulator = 3;
}

// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to
// apply them using the actual optimization algorithm). The extra message is to
// wrap the enum for scoping.
message GradientAccumulationStatus {
  // If UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
}

// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled (by default).
message HotIdOptimizerConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If UNSPECIFIED (default), hot ID optimization is DISABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
  Status status = 1;

  // The following fields are never expected to be set by the TF model.
  // However, a TF model could set them if it chooses to. If the fields are not
  // set, meaningful default values will be chosen by the TPU software.

  // Frequency above which an embedding ID is classified as hot. The valid
  // range for the frequency is [0.0, 1.0]. The frequency of an embedding ID is
  // defined as the ratio of the number of lookups for that ID to the total
  // number of lookups for the embedding table.
  float frequency_threshold = 2;

  // The maximum number of hot IDs for the embedding table. If more than
  // max_id_count hot IDs exist for the table, the IDs with the highest
  // frequencies are chosen.
  int32 max_id_count = 3;

  // The maximum number of slots reserved in HBM (across the entire TPU system)
  // for storing the replicas of hot IDs for the embedding table. In the
  // future, the number of replicas for a particular hot ID could be adjusted
  // based on its frequency. The max_slot_count value captures the total number
  // of replicas across all hot IDs for the table.
  int32 max_slot_count = 4;
}

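// Illustrative example (made-up values; in practice these fields are normally
// left unset so the TPU software picks its own defaults): hot ID optimization
// enabled for a table where IDs receiving at least 1% of the table's lookups
// are treated as hot, capped at 64 hot IDs and 256 replica slots.
//
//   hot_id_optimizer_configuration {
//     status: ENABLED
//     frequency_threshold: 0.01
//     max_id_count: 64
//     max_slot_count: 256
//   }
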
message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. Almost all optimizers are supported with this option (MDL Adagrad
  // Light does not work, and SGD does not behave as expected if it is
  // enabled). Although there is no check, users who want weight decay will
  // probably also want to enable gradient accumulation so that the decay
  // happens once per minibatch.
  float weight_decay_factor = 16;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm).
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Configuration proto for hot ID optimization. This is an experimental
  // feature that is currently disabled (by default).
  HotIdOptimizerConfiguration hot_id_optimizer_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines
  // which algorithm to use.
  oneof parameters {
    AdagradParameters adagrad = 3;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
  }

  reserved 15;  // Old use_gradient_accumulation.
}

// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops.
  message UserDefined {
    // For padding embedding rows, this field specifies the initial value to be
    // used. Separate initial values need to be specified for the embeddings
    // and any extra accumulators. The initial values should be specified so as
    // to maintain two invariants during model training:
    // (1) The embedding vector multiplied by zero returns a vector containing
    //     all zeros. To maintain this invariant, the embedding values should
    //     never be NaNs or +-infinity.
    // (2) Repeatedly applying the optimizer using a gradient vector of all
    //     zeros does not cause the embeddings or slot variables to become NaNs
    //     or +-infinity.
    // The padding row is looked up when no embedding IDs are present for a
    // feature. The semantics of embedding lookup dictate that the output must
    // be zero under this scenario.
    double padding_initial_value = 1;
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    double initial_value = 1;
  }

  // Usage type of this state variable.
  oneof usage {
    UserDefined user_defined = 2;
    FillWithConstant fill_with_constant = 3;
  }
}
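
// Illustrative end-to-end example (all values are made up): an
// OptimizationParameters message, in text format, configuring Adagrad with a
// constant learning rate, gradient clipping, and gradient accumulation
// explicitly enabled. Note that the ClippingLimits bounds are FloatValue
// wrappers, hence the nested value fields.
//
//   learning_rate { constant: 0.1 }
//   gradient_clipping_limits { lower { value: -10.0 } upper { value: 10.0 } }
//   gradient_accumulation_status: ENABLED
//   adagrad { initial_accumulator: 0.1 }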