# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import math
import re

from typing import Optional

import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
# as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size',
        'dimension',
        'initializer',
        'combiner',
        'hot_id_replication',
        'learning_rate',
        'learning_rate_fn',
        'optimization_parameters',
    ])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean',
              hot_id_replication=False,
              learning_rate=None,
              learning_rate_fn=None,
              optimization_parameters=None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
        standard deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more
        information, see `tf.nn.embedding_lookup_sparse`. None is only valid
        for dense rather than sparse tensors.
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
        learning_rate and learning_rate_fn are both `None`, the static learning
        rate as specified in the local `optimization_parameters` will be used.
        In case the local `optimization_parameters` is `None`, the global
        `optimization_parameters` in the `TPUEmbedding` constructor will be
        used. `learning_rate_fn` must be `None` if `learning_rate` is not
        `None`.
      learning_rate_fn: string, use dynamic learning rate given by the
        function. This function will be passed the current global step. If
        learning_rate and learning_rate_fn are both `None`, the static learning
        rate as specified in `optimization_parameters` is used. `learning_rate`
        must be `None` if `learning_rate_fn` is not `None`.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Specifies table level optimizer.
        If it's `None` the global optimizer in the `TPUEmbedding` constructor
        is used.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not positive integer.
      ValueError: if `dimension` is not positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError('Invalid dimension {}.'.format(dimension))

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError('initializer must be callable if specified.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError('Invalid combiner {}'.format(combiner))

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set; got {} and {}'.format(
                           learning_rate, learning_rate_fn))

    if optimization_parameters is not None:
      if not isinstance(optimization_parameters, _OptimizationParameters):
        raise ValueError('`optimization_parameters` must inherit from '
                         '`_OptimizationParameters`. '
                         '`type(optimization_parameters)`={}'.format(
                             type(optimization_parameters)))

    return super(TableConfig,
                 cls).__new__(cls, vocabulary_size, dimension, initializer,
                              combiner, hot_id_replication, learning_rate,
                              learning_rate_fn, optimization_parameters)


class FeatureConfig(
    collections.namedtuple('FeatureConfig',
                           ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
    """Feature configuration.

    Args:
      table_id: Which table the feature uses for embedding lookups.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      weight_key: If using weights for the combiner, this key specifies which
        input feature contains the weights.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `max_sequence_length` is not a non-negative integer.
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError(
          'Invalid max_sequence_length {}.'.format(max_sequence_length))

    return super(FeatureConfig, cls).__new__(cls, table_id,
                                             max_sequence_length, weight_key)
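
# Illustrative sketch (not part of the module's API): a table and a feature
# that looks up into it are typically configured together. The names 'video'
# and 'watched' are hypothetical; see the `TPUEmbedding` class docstring below
# for a fuller example.
#
#   table_to_config_dict = {'video': TableConfig(vocabulary_size=1000,
#                                                dimension=16)}
#   feature_to_config_dict = {'watched': FeatureConfig('video')}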


class EnqueueData(
    collections.namedtuple(
        'EnqueueData',
        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
  """Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_indices=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to int32
        internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
        None, we assume all weights are 1. Both float32 and float64 are allowed
        and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.

    """
    return super(EnqueueData, cls).__new__(cls, embedding_indices,
                                           sample_indices, aggregation_weights)

  @staticmethod
  def from_sparse_tensor(sp_tensor, weights=None):
    return EnqueueData(
        sp_tensor.values,
        sp_tensor.indices,
        aggregation_weights=weights.values if weights is not None else None)
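
# A minimal sketch of how `EnqueueData` is usually built from an input batch;
# `watched_ids` is a hypothetical `tf.sparse.SparseTensor` of feature ids.
#
#   enqueue_data = EnqueueData.from_sparse_tensor(watched_ids)
#   # Equivalent to:
#   enqueue_data = EnqueueData(watched_ids.values, watched_ids.indices)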


class RaggedEnqueueData(
    collections.namedtuple(
        'RaggedEnqueueData',
        ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
  """RaggedTensor data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_splits=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to ids.values in embedding_lookup(), when ids is a
        RaggedTensor. Both int32 and int64 are allowed and will be converted to
        int32 internally.
      sample_splits: A rank 1 Tensor specifying the break points for splitting
        embedding_indices and aggregation_weights into rows. It corresponds to
        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
        int32 and int64 are allowed and will be converted to int32 internally.
      aggregation_weights: A rank 1 Tensor containing per training example
        aggregation weights. It corresponds to the values field of a
        RaggedTensor with the same row_splits as ids in embedding_lookup(),
        when ids is a RaggedTensor.

    Returns:
      A RaggedEnqueueData tuple.

    """
    return super(RaggedEnqueueData,
                 cls).__new__(cls, embedding_indices, sample_splits,
                              aggregation_weights)

  @staticmethod
  def from_ragged_tensor(rg_tensor, weights=None):
    return RaggedEnqueueData(
        rg_tensor.values,
        rg_tensor.row_splits,
        aggregation_weights=weights.values if weights is not None else None)


def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    sp_tensors_list: a list of dictionaries mapping from string of feature
      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from string
      of feature names to EnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.

  """
  enqueue_datas_list = []
  for sp_tensors in sp_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, EnqueueData.from_sparse_tensor(v))
        for k, v in six.iteritems(sp_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list
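
# A minimal sketch of building per-core enqueue data for
# `generate_enqueue_ops()`; `sparse_features_per_core` is a hypothetical list
# with one `{feature_name: SparseTensor}` dict per TPU core, ordered so that
# dicts for the same host are contiguous.
#
#   enqueue_datas_list = get_enqueue_datas_list_from_sparse_tensors_list(
#       sparse_features_per_core)
#   enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)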


def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    rg_tensors_list: a list of dictionaries mapping from string of feature
      names to RaggedTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from string
      of feature names to RaggedEnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.

  """
  enqueue_datas_list = []
  for rg_tensors in rg_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, RaggedEnqueueData.from_ragged_tensor(v))
        for k, v in six.iteritems(rg_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list


AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
                                               ['m', 'v'])

AdagradSlotVariableName = collections.namedtuple('AdagradSlotVariableName',
                                                 ['accumulator'])

MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName',
                                                  ['momenta'])

RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
                                                  ['ms', 'mom'])

ProximalAdagradSlotVariableName = collections.namedtuple(
    'ProximalAdagradSlotVariableName', ['accumulator'])

FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
                                              ['accumulator', 'linear'])

ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])

FrequencyEstimatorSlotVariableName = collections.namedtuple(
    'FrequencyEstimatorSlotVariableName', ['last_hit_step'])

AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])

MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
                                              ['momenta'])

RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
                                              ['ms', 'mom'])

AdagradSlotVariable = collections.namedtuple('AdagradSlotVariable',
                                             ['accumulator'])

ProximalAdagradSlotVariable = collections.namedtuple(
    'ProximalAdagradSlotVariable', ['accumulator'])

FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
                                          ['accumulator', 'linear'])

ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                   ['v', 'm'])

FrequencyEstimatorSlotVariables = collections.namedtuple(
    'FrequencyEstimatorSlotVariables', ['last_hit_step'])

VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops'
])


class _OptimizationParameters(object):
  """Parameters common to all optimizations."""

  def __init__(
      self,
      learning_rate: float,
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: Optional[bool],
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)
    self.clip_gradient_min = clip_gradient_min
    self.clip_gradient_max = clip_gradient_max

    if not use_gradient_accumulation and (clip_gradient_min is not None or
                                          clip_gradient_max is not None):
      raise ValueError('When using gradient clipping limits, gradient '
                       'accumulation must be enabled.')
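
# Gradient clipping limits require gradient accumulation to stay enabled; for
# example, a configuration like the following is invalid (illustrative values):
#
#   AdagradParameters(learning_rate=0.1,
#                     use_gradient_accumulation=False,
#                     clip_gradient_min=-1.0)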


@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(AdagradParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError('Adagrad initial_accumulator must be positive')
    self.initial_accumulator = initial_accumulator


class ProximalAdagradParameters(_OptimizationParameters):
  """Optimization parameters for ProximalAdagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
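
  For example (an illustrative sketch mirroring the other optimizer classes in
  this module; the surrounding estimator arguments are elided):

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalAdagradParameters(0.1),
          ...))
  ```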
  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for ProximalAdagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(ProximalAdagradParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError('Adagrad initial_accumulator must be positive')
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.initial_accumulator = initial_accumulator
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength


@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-08,
      lazy_adam: bool = True,
      sum_inside_sqrt: bool = True,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
        `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(AdamParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling Lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


@tf_export(v1=['tpu.experimental.FtrlParameters'])
class FtrlParameters(_OptimizationParameters):
  """Optimization parameters for Ftrl with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      learning_rate_power: float = -0.5,
      initial_accumulator_value: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      multiply_linear_by_learning_rate: bool = False,
      beta: float = 0,
      allow_zero_accumulator: bool = False,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Ftrl.

    Implements FTRL as described in the following [paper](
    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)

    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      multiply_linear_by_learning_rate: When true, multiplies the usages of the
        linear slot in the weight update by the learning rate. This is useful
        when ramping up learning rate from 0 (which would normally produce
        NaNs).
      beta: The beta parameter for FTRL.
      allow_zero_accumulator: Changes the implementation of the square root to
        allow for the case of initial_accumulator_value being zero. This will
        cause a slight performance drop.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(FtrlParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if learning_rate_power > 0.:
      raise ValueError('learning_rate_power must be less than or equal to 0. '
                       'got {}.'.format(learning_rate_power))

    if initial_accumulator_value < 0.:
      raise ValueError('initial_accumulator_value must be greater than or '
                       'equal to 0. got {}.'.format(initial_accumulator_value))

    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.learning_rate_power = learning_rate_power
    self.initial_accumulator_value = initial_accumulator_value
    self.initial_linear_value = 0.0
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
    self.beta = beta
    self.allow_zero_accumulator = allow_zero_accumulator


class ProximalYogiParameters(_OptimizationParameters):
  # pylint: disable=line-too-long
  """Optimization parameters for Proximal Yogi with TPU embeddings.

  Implements the Yogi optimizer as described in
  [Adaptive Methods for Nonconvex
  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
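
  For example (an illustrative sketch mirroring the other optimizer classes in
  this module; the surrounding estimator arguments are elided):

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalYogiParameters(),
          ...))
  ```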
  """

  # pylint: enable=line-too-long

  def __init__(
      self,
      learning_rate: float = 0.01,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-3,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      initial_accumulator_value: float = 1e-6,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Proximal Yogi.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(ProximalYogiParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))
    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.initial_accumulator_value = initial_accumulator_value


class MomentumParameters(_OptimizationParameters):
  """Optimization parameters for Momentum with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      momentum: float,
      use_nesterov: bool = False,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for momentum.

    Args:
      learning_rate: a floating point value. The learning rate.
      momentum: A `Tensor` or a floating point value. The momentum.
      use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
        2013). This implementation always computes gradients at the value of
        the variable(s) passed to the optimizer. Using Nesterov Momentum makes
        the variable(s) track the values called `theta_t + mu*v_t` in the
        paper. This implementation is an approximation of the original formula,
        valid for high values of momentum. It will compute the "adjusted
        gradient" in NAG by assuming that the new gradient will be estimated by
        the current average gradient plus the product of momentum and the
        change in the average gradient.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(MomentumParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.momentum = momentum
    self.use_nesterov = use_nesterov


class RMSPropParameters(_OptimizationParameters):
  """Optimization parameters for RMSProp with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=RMSPropParameters(...),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      rho: float,
      momentum: float,
      epsilon: float,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for RMS prop.

    Args:
      learning_rate: a floating point value. The learning rate.
      rho: Discounting factor for the history/coming gradient.
      momentum: A scalar tensor.
      epsilon: Small value to avoid zero denominator.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(RMSPropParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.rho = rho
    self.momentum = momentum
    self.epsilon = epsilon


@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=(
              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: a floating point value. The learning rate.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
    """
    # Gradient accumulation is generally a no-op for SGD, but if gradient
    # clipping is enabled, then we must also enable gradient accumulation.
    # In the other optimizers this is up to the user, but we don't give the
    # user the option to turn gradient accumulation on or off for SGD.
    use_gradient_accumulation = False
    if (clip_gradient_min is not None or clip_gradient_max is not None):
      use_gradient_accumulation = True
    super(StochasticGradientDescentParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )


class FrequencyEstimatorParameters(_OptimizationParameters):
  """Optimization parameters for Frequency Estimator TPU embeddings.

  This is a non-standard optimizer, which returns the estimated frequency of
  lookup for the feature passed to it. It should only be used on a table of
  width 1. The gradient fed back to the TPU embedding should always be zero.
  This can be accomplished by using `tf.stop_gradient` on the feature before
  using it.

  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to be a float32 cast of the global training step counter.

  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
  details on this optimizer.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=FrequencyEstimatorParameters(...),
          ...))
  ```

  """

  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
               weight_exponent: float):
    """Optimization parameters for frequency estimator.

    Args:
      tau: Learning rate between (0, 1) that is used to update the array.
      max_delta: Maximum value of delta, the difference between the current
        global step and the last global step at which the row was sampled.
      outlier_threshold: Threshold used to determine whether the current update
        is an outlier.
      weight_exponent: The weight exponent used to transform the estimated
        delta into weights.
    """
    super(FrequencyEstimatorParameters, self).__init__(
        learning_rate=1.0,
        use_gradient_accumulation=True,
        clip_weight_min=None,
        clip_weight_max=None,
        weight_decay_factor=None,
        multiply_weight_decay_factor_by_learning_rate=None,
    )
    self.tau = tau
    self.max_delta = max_delta
    self.outlier_threshold = outlier_threshold
    self.weight_exponent = weight_exponent


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])


class TPUEmbedding(object):
  """API for using TPU for embedding.

  Example:
  ```
  table_config_user = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
  table_to_config_dict = {'video': table_config_video,
                          'user': table_config_user}
  feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
                            'favorited': tpu_embedding.FeatureConfig('video'),
                            'friends': tpu_embedding.FeatureConfig('user')}
  batch_size = 4
  num_hosts = 1
  optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
  mode = tpu_embedding.TRAINING
  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict, feature_to_config_dict,
      batch_size, num_hosts, mode, optimization_parameters)

  batch_size_per_core = embedding.batch_size_per_core
  sparse_features_list = []
  for host in hosts:
    with ops.device(host):
      for _ in range(embedding.num_cores_per_host):
        sparse_features = {}
        sparse_features['watched'] = sparse_tensor.SparseTensor(...)
        sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
        sparse_features['friends'] = sparse_tensor.SparseTensor(...)
        sparse_features_list.append(sparse_features)

  enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
  embedding_variables_and_ops = embedding.create_variables_and_ops()

  def computation():
    activations = embedding.get_activations()
    loss = compute_loss(activations)

    base_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=1)
    cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
        base_optimizer)

    train_op = cross_shard_optimizer.minimize(loss)
    gradients = (
        tpu_embedding_gradient.get_gradients_through_compute_gradients(
            cross_shard_optimizer, loss, activations))
    send_gradients_op = embedding.generate_send_gradients_op(gradients)
    with ops.control_dependencies([train_op, send_gradients_op]):
      loss = array_ops.identity(loss)

  loss = tpu.shard(computation,
                   num_shards=embedding.num_cores)

  with self.test_session() as sess:
    sess.run(tpu.initialize_system(embedding_config=
                                   embedding.config_proto))
    sess.run(variables.global_variables_initializer())
    sess.run(embedding_variables_and_ops.load_ops())
    sess.run(enqueue_ops)
    loss_val = sess.run(loss)
  ```

  Example with weight decay:

  >>> def learning_rate_fn(global_step):
  ...   return tf.compat.v1.train.polynomial_decay(
  ...     learning_rate=5e-5,
  ...     global_step=global_step,
  ...     decay_steps=100000,
  ...     end_learning_rate=0.0)
  >>> wordpiece_table_config = TableConfig(
  ...   vocabulary_size=119547,
  ...   dimension=256,
  ...   learning_rate_fn=learning_rate_fn)
  >>> wordpiece_feature_config = FeatureConfig(
  ...   table_id='bert/embeddings/word_embeddings',
  ...   max_sequence_length=512)
  >>> optimization_parameters = AdamParameters(
  ...   learning_rate=5e-5,
  ...   epsilon=1e-6,
  ...   weight_decay_factor=0.01,
  ...   multiply_weight_decay_factor_by_learning_rate=True)
  >>> tpu_embedding = TPUEmbedding(
  ...   table_to_config_dict={
  ...       'bert/embeddings/word_embeddings': wordpiece_table_config,
  ...   },
  ...   feature_to_config_dict={'input_ids': wordpiece_feature_config},
  ...   batch_size=128,
  ...   mode=TRAINING,
  ...   optimization_parameters=optimization_parameters,
  ...   master='')
  >>> with tf.Graph().as_default():
  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
  ...       embedding_config=tpu_embedding.config_proto)
  ...   tf.compat.v1.Session().run(init_tpu_op)
  """

  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
  # the feature should not be used to update embedding table (cr/204852758,
  # cr/204940540). Also, this can support different combiners for different
  # features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove
  # boilerplate for-loops around construction of inputs.

  # `optimization_parameter` applies to all tables. If the need arises,
  # we can add `optimization_parameters` to `TableConfig` to override this
  # global setting.
  def __init__(self,
               table_to_config_dict,
               feature_to_config_dict,
               batch_size,
               mode,
               master=None,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               profile_data_directory=None,
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_config_dict: A dictionary mapping from string of feature name
        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Must be set in training unless
        all tables specify their own optimizers. It must be `None` in
        inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes
        training faster, but trained model will be different if step N and
        step N+1 involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
      partition_strategy: A string, either 'mod' or 'div', specifying how to
        map the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse`.
      profile_data_directory: Directory where embedding lookup statistics are
        stored. These statistics summarize information about the inputs to the
        embedding lookup operation, in particular, the average number of
        embedding IDs per example and how well the embedding IDs are load
        balanced across the system. The lookup statistics are used during TPU
        initialization for embedding table partitioning. Collection of lookup
        statistics is done at runtime by profiling the embedding inputs: only
        3% of input samples are profiled to minimize host CPU overhead. Once
        a suitable number of samples are profiled, the lookup statistics are
        saved to table-specific files in the profile data directory, generally
        at the end of a TPU training loop. The filename corresponding to each
        table is obtained by hashing table specific parameters (e.g., table
        name and number of features) and global configuration parameters
        (e.g., sharding strategy and task count). The same profile data
        directory can be shared among several models to reuse embedding lookup
        statistics.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: if set, overrides the master job name used to schedule
        embedding ops.

    Raises:
      ValueError: if any input is invalid.
    """
    if partition_strategy not in ('div', 'mod'):
      raise ValueError(
          'Invalid partition_strategy {}'.format(partition_strategy))
    self._partition_strategy = partition_strategy

    self._profile_data_directory = profile_data_directory

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict)
    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
    self._table_to_features_dict, self._table_to_num_features_dict = (
        _create_table_to_features_and_num_features_dicts(
            self._feature_to_config_dict))
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    if master is None and cluster_def is None:
      if device_config is None:
        raise ValueError('When master and cluster_def are both None, '
                         'device_config must be set but is not.')
      if device_config.num_cores % device_config.num_hosts:
        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
                         'but does not.'.format(device_config.num_hosts,
                                                device_config.num_cores))
      self._num_hosts = device_config.num_hosts
      self._num_cores = device_config.num_cores
      self._num_cores_per_host = self._num_cores // self._num_hosts
      self._hosts = [
          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
          for i in range(self._num_hosts)
      ]
    else:
      tpu_system_metadata = (
          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
              master,
              cluster_def=cluster_def))
      if tpu_system_metadata.num_cores == 0:
        raise ValueError('TPUEmbedding needs TPUs, but master {} does not '
                         'have TPUs.'.format(master))
      self._num_hosts = tpu_system_metadata.num_hosts
      if master_job_name is None:
        try:
          master_job_name = tpu_system_metadata_lib.master_job(
              master, cluster_def)
        except ValueError as e:
          raise ValueError(str(e) + ' Please specify a master_job_name.')
      self._hosts = []
      for device in tpu_system_metadata.devices:
        if 'device:CPU:' in device.name and (master_job_name is None or
                                             master_job_name in device.name):
          self._hosts.append(device.name)
      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
      self._num_cores = tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores
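    # For example, a global batch_size of 128 on 32 TPU cores gives 4 examples
    # per core per step; the tensors enqueued for each core must use this
    # per-core batch dimension (see `batch_size_per_core`).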

    # TODO(shizhiw): remove `mode`?
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters,
                                        self._table_to_config_dict)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError('`optimization_parameters` should be `None` '
                         'for inference mode.')
      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
          TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
    self._optimizer_handler_dict = self._get_optimizer_handler_by_table()

    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
    self._learning_rate_fn = list(
        set(c.learning_rate_fn
            for c in self._table_to_config_dict.values()
            if c.learning_rate_fn is not None))
    self._learning_rate_fn_to_tag = {
        fn: id for id, fn in enumerate(self._learning_rate_fn)
    }

    self._config_proto = self._create_config_proto()

  @property
  def hosts(self):
    """A list of device names for CPU hosts.

    Returns:
      A list of device names for CPU hosts.
    """
    return copy.copy(self._hosts)

  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
  # to be consistent with `tpu_embedding_configuration.proto`.
  @property
  def num_cores_per_host(self):
    """Number of TPU cores on a CPU host.

    Returns:
      Number of TPU cores on a CPU host.
    """
    return self._num_cores_per_host

  @property
  def num_cores(self):
    """Total number of TPU cores on all hosts.

    Returns:
      Total number of TPU cores on all hosts.
    """
    return self._num_cores

  @property
  def batch_size_per_core(self):
    """Batch size for each TPU core.

    The sparse tensors in `sparse_features_list` to `generate_enqueue_ops`
    must have batch dimension equal to this.

    Returns:
      Batch size for each TPU core.
    """
    return self._batch_size_per_core

  @property
  def config_proto(self):
    """Create embedding config proto for `tpu.initialize_system()`.

    Returns:
      a `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables, which
      is passed to `tpu.initialize_system()`.
1426 """ 1427 return self._config_proto 1428 1429 @property 1430 def table_to_config_dict(self): 1431 return copy.copy(self._table_to_config_dict) 1432 1433 @property 1434 def feature_to_config_dict(self): 1435 return copy.copy(self._feature_to_config_dict) 1436 1437 @property 1438 def table_to_features_dict(self): 1439 return copy.copy(self._table_to_features_dict) 1440 1441 @property 1442 def optimization_parameters(self): 1443 return self._optimization_parameters 1444 1445 def _create_config_proto(self): 1446 """Create `TPUEmbeddingConfiguration`.""" 1447 config_proto = elc.TPUEmbeddingConfiguration() 1448 for table in self._table_to_config_dict: 1449 table_descriptor = config_proto.table_descriptor.add() 1450 table_descriptor.name = table 1451 1452 table_config = self._table_to_config_dict[table] 1453 # For small tables, we pad to the number of hosts so that at least one 1454 # id will be assigned to each host. 1455 table_descriptor.vocabulary_size = max(table_config.vocabulary_size, 1456 len(self.hosts)) 1457 table_descriptor.dimension = table_config.dimension 1458 1459 table_descriptor.num_features = self._table_to_num_features_dict[table] 1460 1461 optimization_parameters = ( 1462 self._optimizer_handler_dict[table].get_optimization_parameters()) 1463 1464 parameters = table_descriptor.optimization_parameters 1465 if table_config.learning_rate: 1466 parameters.learning_rate.constant = table_config.learning_rate 1467 elif table_config.learning_rate_fn: 1468 parameters.learning_rate.dynamic.tag = ( 1469 self._learning_rate_fn_to_tag[table_config.learning_rate_fn]) 1470 else: 1471 parameters.learning_rate.constant = ( 1472 optimization_parameters.learning_rate) 1473 parameters.gradient_accumulation_status = ( 1474 optimization_parameters_pb2.GradientAccumulationStatus.ENABLED 1475 if optimization_parameters.use_gradient_accumulation else 1476 optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) 1477 1478 if optimization_parameters.clip_gradient_min is not None: 1479 parameters.gradient_clipping_limits.lower.value = ( 1480 optimization_parameters.clip_gradient_min) 1481 if optimization_parameters.clip_gradient_max is not None: 1482 parameters.gradient_clipping_limits.upper.value = ( 1483 optimization_parameters.clip_gradient_max) 1484 1485 if optimization_parameters.clip_weight_min is not None: 1486 parameters.clipping_limits.lower.value = ( 1487 optimization_parameters.clip_weight_min) 1488 if optimization_parameters.clip_weight_max is not None: 1489 parameters.clipping_limits.upper.value = ( 1490 optimization_parameters.clip_weight_max) 1491 if optimization_parameters.weight_decay_factor: 1492 parameters.weight_decay_factor = ( 1493 optimization_parameters.weight_decay_factor) 1494 if (optimization_parameters 1495 .multiply_weight_decay_factor_by_learning_rate): 1496 parameters.multiply_weight_decay_factor_by_learning_rate = True 1497 if table_config.hot_id_replication: 1498 parameters.hot_id_replication_configuration.status = ( 1499 optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED) 1500 optimizer_handler = self._optimizer_handler_dict[table] 1501 optimizer_handler.set_optimization_parameters(table_descriptor) 1502 1503 config_proto.mode = self._mode 1504 config_proto.batch_size_per_tensor_core = self._batch_size_per_core 1505 config_proto.num_hosts = self._num_hosts 1506 config_proto.num_tensor_cores = self._num_cores 1507 config_proto.sharding_strategy = ( 1508 elc.TPUEmbeddingConfiguration.DIV_DEFAULT 1509 if self._partition_strategy == 'div' else 1510 
        elc.TPUEmbeddingConfiguration.MOD)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)
    if self._profile_data_directory:
      config_proto.profile_data_directory = self._profile_data_directory

    return config_proto

  def create_variables_and_ops(self,
                               embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

    N.B.: the ops that retrieve the embedding variables (including slot
    variables) are returned as lambda functions, so that the caller can impose
    control dependencies between the TPU computation and the retrieval. For
    example, the following code snippet ensures the TPU computation finishes
    first, and only then are the variables pulled back from TPU to CPU.

    ```
    update_ops = []
    with ops.control_dependencies([loss]):
      for op_fn in retrieve_parameters_op_fns:
        update_ops.append(op_fn())
    ```

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the table
        name is used as the embedding variable name.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.

    Returns:
      `tpu_embedding.VariablesAndOps` with:
        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to AdagradSlotVariable,
          AdamSlotVariables etc with slot variables,
        A function which returns a list of ops to load embedding and slot
          variables from CPU to TPU.
        A function which returns a list of ops to retrieve embedding and slot
          variables from TPU to CPU.
    """
    embedding_variables_by_table = {}
    slot_variables_by_table = {}
    load_op_fns = []
    retrieve_op_fns = []

    for i, table in enumerate(self._table_to_config_dict):
      if embedding_variable_name_by_table:
        embedding_variable_name = embedding_variable_name_by_table[table]
      else:
        embedding_variable_name = table
      if slot_variable_names_by_table:
        slot_variable_names = slot_variable_names_by_table[table]
      else:
        optimizer_handler = self._optimizer_handler_dict[table]
        slot_variable_names = (
            optimizer_handler.get_default_slot_variable_names(table))

      # TODO(b/139144091): Multi-host support for mid-level API in
      # eager context (TF 2.0)
      # Workaround below allows single-host use case in TF 2.0
      if context.executing_eagerly():
        device = ''
      else:
        device = _create_device_fn(self._hosts)

      with ops.device(device):
        table_variables = _create_partitioned_variables(
            name=embedding_variable_name,
            num_hosts=self._num_hosts,
            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
            embedding_dimension=self._table_to_config_dict[table].dimension,
            initializer=self._table_to_config_dict[table].initializer,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        embedding_variables_by_table[table] = table_variables

        # Only load the embedding config onto the load/retrieve nodes for the
        # first table on the first host; other nodes reuse the config from the
        # first node.
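        # Concretely: `i == 0` only for the first table, so only that table's
        # load/retrieve ops receive the serialized config below; every later
        # table passes `config=None`.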
1591 config = None if i else self.config_proto.SerializeToString() 1592 slot_variables_for_table, load_ops_fn, retrieve_ops_fn = ( 1593 self._optimizer_handler_dict[table].create_variables_and_ops( 1594 table, slot_variable_names, self._num_hosts, 1595 self._table_to_config_dict[table], table_variables, config)) 1596 slot_variables_by_table[table] = slot_variables_for_table 1597 load_op_fns.append(load_ops_fn) 1598 retrieve_op_fns.append(retrieve_ops_fn) 1599 1600 def load_ops(): 1601 """Calls and returns the load ops for each embedding table. 1602 1603 Returns: 1604 A list of ops to load embedding and slot variables from CPU to TPU. 1605 """ 1606 load_ops_list = [] 1607 for load_op_fn in load_op_fns: 1608 load_ops_list.extend(load_op_fn()) 1609 return load_ops_list 1610 1611 def retrieve_ops(): 1612 """Calls and returns the retrieve ops for each embedding table. 1613 1614 Returns: 1615 A list of ops to retrieve embedding and slot variables from TPU to CPU. 1616 """ 1617 retrieve_ops_list = [] 1618 for retrieve_op_fn in retrieve_op_fns: 1619 retrieve_ops_list.extend(retrieve_op_fn()) 1620 return retrieve_ops_list 1621 1622 return VariablesAndOps(embedding_variables_by_table, 1623 slot_variables_by_table, load_ops, retrieve_ops) 1624 1625 def generate_enqueue_ops( 1626 self, 1627 enqueue_datas_list, 1628 mode_override=None, 1629 ragged=False, 1630 ): 1631 """Generate enqueue ops. 1632 1633 Args: 1634 enqueue_datas_list: a list of dictionary mapping from string of feature 1635 names to EnqueueData. Each dictionary is for one TPU core. Dictionaries 1636 for the same host should be contiguous in the list. 1637 mode_override: A string input that overrides the mode specified in the 1638 TPUEmbeddingConfiguration. Supported values are {'unspecified', 1639 'inference', 'training', 'backward_pass_only'}. When set to 1640 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 1641 otherwise mode_override is used (optional). 1642 ragged: If True, creates RaggedTensor enqueue ops rather than 1643 SparseTensor. 1644 1645 Returns: 1646 Ops to enqueue to TPU for embedding. 1647 """ 1648 self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list) 1649 return [ 1650 self._generate_enqueue_op( # pylint: disable=g-complex-comprehension 1651 enqueue_datas, 1652 device_ordinal=i % self._num_cores_per_host, 1653 mode_override=mode_override, 1654 ragged=ragged, 1655 ) for i, enqueue_datas in enumerate(enqueue_datas_list) 1656 ] 1657 1658 def _validate_generate_enqueue_ops_enqueue_datas_list(self, 1659 enqueue_datas_list): 1660 """Validate `enqueue_datas_list`.""" 1661 1662 def _check_agreement(data, name, feature, enqueue_data): 1663 """Helper function to check device agreement.""" 1664 if (data is not None and 1665 data.device != enqueue_data.embedding_indices.device): 1666 raise ValueError('Device of {0} does not agree with that of' 1667 'embedding_indices for feature {1}.'.format( 1668 name, feature)) 1669 1670 feature_set = set(self._feature_to_config_dict.keys()) 1671 contiguous_device = None 1672 for i, enqueue_datas in enumerate(enqueue_datas_list): 1673 used_feature_set = set(enqueue_datas.keys()) 1674 1675 # Check features are valid. 
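      # The check runs in both directions: every configured feature must be
      # present in this core's dict, and no unknown feature may appear in it.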
      missing_feature_set = feature_set - used_feature_set
      if missing_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` is missing a feature that '
                         'is in `feature_to_config_dict`: {}.'.format(
                             i, missing_feature_set))

      extra_feature_set = used_feature_set - feature_set
      if extra_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, extra_feature_set))

      device = None
      device_feature = None
      for feature, enqueue_data in six.iteritems(enqueue_datas):
        combiner = self._table_to_config_dict[
            self._feature_to_config_dict[feature].table_id].combiner

        if isinstance(enqueue_data, EnqueueData):
          if enqueue_data.sample_indices is None and combiner:
            logging.warn(
                'No sample indices set for feature %s table %s but '
                'combiner is set to %s.', feature,
                self._feature_to_config_dict[feature].table_id, combiner)
          _check_agreement(enqueue_data.sample_indices, 'sample_indices',
                           feature, enqueue_data)
          _check_agreement(enqueue_data.aggregation_weights,
                           'aggregation_weights', feature, enqueue_data)

        elif isinstance(enqueue_data, RaggedEnqueueData):
          if enqueue_data.sample_splits is None and combiner:
            logging.warn(
                'No sample splits set for feature %s table %s but '
                'combiner is set to %s.', feature,
                self._feature_to_config_dict[feature].table_id, combiner)
          _check_agreement(enqueue_data.sample_splits, 'sample_splits', feature,
                           enqueue_data)
          _check_agreement(enqueue_data.aggregation_weights,
                           'aggregation_weights', feature, enqueue_data)
        else:
          raise ValueError(
              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
                  i, feature))
        # Check all features are on the same device.
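        # `device` is taken from the first feature seen for this core; every
        # other feature on the same core must resolve to the same device, and
        # cores that share a host must appear contiguously in
        # `enqueue_datas_list` (checked after this loop).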
1721 if device is None: 1722 device = enqueue_data.embedding_indices.device 1723 device_feature = feature 1724 else: 1725 if device != enqueue_data.embedding_indices.device: 1726 raise ValueError('Devices are different between features in ' 1727 '`enqueue_datas_list[{}]`; ' 1728 'devices: {}, {}; features: {}, {}.'.format( 1729 i, device, 1730 enqueue_data.embedding_indices.device, feature, 1731 device_feature)) 1732 1733 if i % self._num_cores_per_host: 1734 if device != contiguous_device: 1735 raise ValueError('We expect the `enqueue_datas` which are on the ' 1736 'same host to be contiguous in ' 1737 '`enqueue_datas_list`, ' 1738 '`enqueue_datas_list[{}]` is on device {}, ' 1739 'but is expected to be on device {}.'.format( 1740 i, device, contiguous_device)) 1741 else: 1742 contiguous_device = device 1743 1744 def _generate_enqueue_op(self, 1745 enqueue_datas, 1746 device_ordinal, 1747 mode_override=None, 1748 ragged=False): 1749 """Creates op for enqueuing batch to TPU.""" 1750 enqueue_data0 = list(enqueue_datas.values())[0] 1751 with ops.colocate_with(enqueue_data0.embedding_indices): 1752 if ragged: 1753 # note that this is currently identical in behavior 1754 return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( 1755 device_ordinal=device_ordinal, 1756 combiners=self._combiners, 1757 mode_override=mode_override, 1758 **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas)) 1759 else: 1760 return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( 1761 device_ordinal=device_ordinal, 1762 combiners=self._combiners, 1763 mode_override=mode_override, 1764 **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)) 1765 1766 def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas): 1767 """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`. 1768 1769 Args: 1770 enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding. 1771 1772 Returns: 1773 Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`. 1774 """ 1775 1776 kwargs = { 1777 'sample_splits': [], 1778 'embedding_indices': [], 1779 'aggregation_weights': [], 1780 'table_ids': [], 1781 'max_sequence_lengths': [], 1782 } 1783 int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) 1784 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 1785 for table_id, table in enumerate(self._table_to_features_dict): 1786 features = self._table_to_features_dict[table] 1787 for feature in features: 1788 enqueue_data = enqueue_datas[feature] 1789 1790 kwargs['sample_splits'].append( 1791 enqueue_data.sample_splits 1792 if enqueue_data.sample_splits is not None else int_zeros) 1793 1794 kwargs['aggregation_weights'].append( 1795 enqueue_data.aggregation_weights 1796 if enqueue_data.aggregation_weights is not None else float_zeros) 1797 1798 kwargs['embedding_indices'].append(enqueue_data.embedding_indices) 1799 1800 kwargs['table_ids'].append(table_id) 1801 kwargs['max_sequence_lengths'].append( 1802 self._feature_to_config_dict[feature].max_sequence_length) 1803 1804 return kwargs 1805 1806 def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas): 1807 """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`. 1808 1809 Args: 1810 enqueue_datas: a `Dict` of `EnqueueData` objects for embedding. 1811 1812 Returns: 1813 Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`. 
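      Each value is a per-feature list, ordered first by table (in
      `_table_to_features_dict` order) and then by feature within each table;
      missing `sample_indices` or `aggregation_weights` entries are replaced
      with empty tensors. For example, with hypothetical tables `video`
      (features `watched`, `favorited`) and `user` (feature `friends`),
      `table_ids` would be `[0, 0, 1]`.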
1814 """ 1815 kwargs = { 1816 'sample_indices': [], 1817 'embedding_indices': [], 1818 'aggregation_weights': [], 1819 'table_ids': [], 1820 'max_sequence_lengths': [], 1821 } 1822 int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) 1823 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 1824 for table_id, table in enumerate(self._table_to_features_dict): 1825 features = self._table_to_features_dict[table] 1826 for feature in features: 1827 enqueue_data = enqueue_datas[feature] 1828 1829 kwargs['sample_indices'].append( 1830 enqueue_data.sample_indices 1831 if enqueue_data.sample_indices is not None else int_zeros) 1832 1833 kwargs['aggregation_weights'].append( 1834 enqueue_data.aggregation_weights 1835 if enqueue_data.aggregation_weights is not None else float_zeros) 1836 1837 kwargs['embedding_indices'].append(enqueue_data.embedding_indices) 1838 1839 kwargs['table_ids'].append(table_id) 1840 kwargs['max_sequence_lengths'].append( 1841 self._feature_to_config_dict[feature].max_sequence_length) 1842 1843 return kwargs 1844 1845 def get_activations(self): 1846 """Get activations for features. 1847 1848 This should be called within `computation` that is passed to 1849 `tpu.replicate` and friends. 1850 1851 Returns: 1852 A dictionary mapping from `String` of feature name to `Tensor` 1853 of activation. 1854 """ 1855 recv_activations = tpu_ops.recv_tpu_embedding_activations( 1856 num_outputs=len(self._table_to_config_dict), 1857 config=self._config_proto.SerializeToString()) 1858 1859 activations = collections.OrderedDict() 1860 for table_id, table in enumerate(self._table_to_features_dict): 1861 features = self._table_to_features_dict[table] 1862 num_features = self._table_to_num_features_dict[table] 1863 feature_index = 0 1864 table_activations = array_ops.reshape( 1865 recv_activations[table_id], 1866 [self.batch_size_per_core, num_features, -1]) 1867 for feature in features: 1868 seq_length = self._feature_to_config_dict[feature].max_sequence_length 1869 if not seq_length: 1870 activations[feature] = table_activations[:, feature_index, :] 1871 feature_index = feature_index + 1 1872 else: 1873 activations[feature] = ( 1874 table_activations[:, 1875 feature_index:(feature_index + seq_length), :]) 1876 feature_index = feature_index + seq_length 1877 1878 return activations 1879 1880 def generate_send_gradients_op(self, feature_to_gradient_dict, step=None): 1881 """Send gradient to TPU embedding. 1882 1883 Args: 1884 feature_to_gradient_dict: dict mapping feature names to gradient wrt 1885 activations. 1886 step: the current global step, used for dynamic learning rate. 1887 1888 Returns: 1889 SendTPUEmbeddingGradients Op. 1890 1891 Raises: 1892 RuntimeError: If `mode` is not `TRAINING`. 1893 """ 1894 if self._mode != TRAINING: 1895 raise RuntimeError('Only in training mode gradients need to ' 1896 'be sent to TPU embedding; got mode {}.'.format( 1897 self._mode)) 1898 if step is None and self._learning_rate_fn: 1899 raise ValueError('There are dynamic learning rates but step is None.') 1900 1901 gradients = [] 1902 for table in self._table_to_features_dict: 1903 features = self._table_to_features_dict[table] 1904 table_gradients = [] 1905 for feature in features: 1906 gradient = feature_to_gradient_dict[feature] 1907 # Expand dims for non-sequence feature to match sequence features. 
1908 if gradient.shape.ndims == 2: 1909 gradient = array_ops.expand_dims(gradient, 1) 1910 table_gradients.append(gradient) 1911 interleaved_table_grads = array_ops.reshape( 1912 array_ops.concat(table_gradients, axis=1), 1913 [-1, array_ops.shape(table_gradients[0])[-1]]) 1914 gradients.append(interleaved_table_grads) 1915 1916 return tpu_ops.send_tpu_embedding_gradients( 1917 inputs=gradients, 1918 learning_rates=[ 1919 math_ops.cast(fn(step), dtype=dtypes.float32) 1920 for fn in self._learning_rate_fn 1921 ], 1922 config=self.config_proto.SerializeToString()) 1923 1924 def _get_optimizer_handler_by_table(self): 1925 optimizer_handlers = {} 1926 for table, table_config in self.table_to_config_dict.items(): 1927 if table_config.optimization_parameters is not None: 1928 optimizer = table_config.optimization_parameters 1929 else: 1930 optimizer = self._optimization_parameters 1931 optimizer_handlers[table] = _get_optimization_handler(optimizer) 1932 1933 return optimizer_handlers 1934 1935 1936def _validate_table_to_config_dict(table_to_config_dict): 1937 """Validate `table_to_config_dict`.""" 1938 for k, v in six.iteritems(table_to_config_dict): 1939 if not isinstance(v, TableConfig): 1940 raise ValueError('Value of `table_to_config_dict` must be of type ' 1941 '`TableConfig`, got {} for {}.'.format(type(v), k)) 1942 1943 1944def _validate_feature_to_config_dict(table_to_config_dict, 1945 feature_to_config_dict): 1946 """Validate `feature_to_config_dict`.""" 1947 used_table_set = set( 1948 [feature.table_id for feature in feature_to_config_dict.values()]) 1949 table_set = set(table_to_config_dict.keys()) 1950 1951 unused_table_set = table_set - used_table_set 1952 if unused_table_set: 1953 raise ValueError( 1954 '`table_to_config_dict` specifies table that is not ' 1955 'used in `feature_to_config_dict`: {}.'.format(unused_table_set)) 1956 1957 extra_table_set = used_table_set - table_set 1958 if extra_table_set: 1959 raise ValueError( 1960 '`feature_to_config_dict` refers to a table that is not ' 1961 'specified in `table_to_config_dict`: {}.'.format(extra_table_set)) 1962 1963 1964def _validate_batch_size(batch_size, num_cores): 1965 if batch_size % num_cores: 1966 raise ValueError('`batch_size` is not a multiple of number of ' 1967 'cores. `batch_size`={}, `_num_cores`={}.'.format( 1968 batch_size, num_cores)) 1969 1970 1971def _validate_optimization_parameters(optimization_parameters, 1972 table_to_config_dict): 1973 """Validate global optimization_parameters and per table optimizers. 1974 1975 If global optimizer is `None`, all table optimizers should be non `None`. 1976 1977 Args: 1978 optimization_parameters: global optimizer provided in `TPUEmbedding` 1979 constructor. 1980 table_to_config_dict: A dictionary mapping from string of table name to 1981 `TableConfig`. 1982 """ 1983 tbl_optimizer_missing = False 1984 for _, table_config in table_to_config_dict.items(): 1985 if table_config.optimization_parameters is None: 1986 tbl_optimizer_missing = True 1987 break 1988 1989 if optimization_parameters: 1990 if not isinstance(optimization_parameters, _OptimizationParameters): 1991 raise ValueError('`optimization_parameters` must inherit from ' 1992 '`_OptimizationParameters`. ' 1993 '`type(optimization_parameters)`={}'.format( 1994 type(optimization_parameters))) 1995 else: 1996 # Missing global optimization_parameters. 
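    # Without a global optimizer there is nothing to fall back on, so every
    # table must have supplied its own `optimization_parameters`.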
1997 if tbl_optimizer_missing: 1998 raise ValueError('`optimization_parameters` is missing.') 1999 2000 2001class _OptimizerHandler(object): 2002 """Interface class for handling optimizer specific logic.""" 2003 2004 def __init__(self, optimization_parameters): 2005 self._optimization_parameters = optimization_parameters 2006 2007 def get_optimization_parameters(self): 2008 return self._optimization_parameters 2009 2010 def set_optimization_parameters(self, table_descriptor): 2011 raise NotImplementedError() 2012 2013 def get_default_slot_variable_names(self, table): 2014 raise NotImplementedError() 2015 2016 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2017 table_config, table_variables, config_proto): 2018 raise NotImplementedError() 2019 2020 2021class _AdagradHandler(_OptimizerHandler): 2022 """Handles Adagrad specific logic.""" 2023 2024 def set_optimization_parameters(self, table_descriptor): 2025 table_descriptor.optimization_parameters.adagrad.SetInParent() 2026 2027 def get_default_slot_variable_names(self, table): 2028 return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad')) 2029 2030 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2031 table_config, table_variables, config_proto): 2032 accumulator_initializer = init_ops.constant_initializer( 2033 self._optimization_parameters.initial_accumulator) 2034 accumulator_variables = _create_partitioned_variables( 2035 name=slot_variable_names.accumulator, 2036 num_hosts=num_hosts, 2037 vocabulary_size=table_config.vocabulary_size, 2038 embedding_dimension=table_config.dimension, 2039 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2040 initializer=accumulator_initializer) 2041 slot_variables = AdagradSlotVariable(accumulator_variables) 2042 2043 def load_ops_fn(): 2044 """Returns the retrieve ops for AdaGrad embedding tables. 2045 2046 Returns: 2047 A list of ops to load embedding and slot variables from CPU to TPU. 2048 """ 2049 config = config_proto 2050 load_op_list = [] 2051 for host_id, table_variable, accumulator_variable in zip( 2052 range(num_hosts), table_variables, accumulator_variables): 2053 with ops.colocate_with(table_variable): 2054 load_parameters_op = ( 2055 tpu_ops.load_tpu_embedding_adagrad_parameters( 2056 parameters=table_variable, 2057 accumulators=accumulator_variable, 2058 table_name=table, 2059 num_shards=num_hosts, 2060 shard_id=host_id, 2061 config=config)) 2062 config = None 2063 load_op_list.append(load_parameters_op) 2064 return load_op_list 2065 2066 def retrieve_ops_fn(): 2067 """Returns the retrieve ops for AdaGrad embedding tables. 2068 2069 Returns: 2070 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
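      As with the load ops, only the retrieve op built for the first shard is
      given the serialized config (when one was provided at all); later shards
      pass `config=None`.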
2071 """ 2072 config = config_proto 2073 retrieve_op_list = [] 2074 for host_id, table_variable, accumulator_variable in (zip( 2075 range(num_hosts), table_variables, accumulator_variables)): 2076 with ops.colocate_with(table_variable): 2077 retrieved_table, retrieved_accumulator = ( 2078 tpu_ops.retrieve_tpu_embedding_adagrad_parameters( 2079 table_name=table, 2080 num_shards=num_hosts, 2081 shard_id=host_id, 2082 config=config)) 2083 retrieve_parameters_op = control_flow_ops.group( 2084 state_ops.assign(table_variable, retrieved_table), 2085 state_ops.assign(accumulator_variable, retrieved_accumulator)) 2086 config = None 2087 retrieve_op_list.append(retrieve_parameters_op) 2088 return retrieve_op_list 2089 2090 return slot_variables, load_ops_fn, retrieve_ops_fn 2091 2092 2093class _ProximalAdagradHandler(_OptimizerHandler): 2094 """Handles ProximalAdagrad specific logic.""" 2095 2096 def set_optimization_parameters(self, table_descriptor): 2097 table_descriptor.optimization_parameters.proximal_adagrad.SetInParent() 2098 table_descriptor.optimization_parameters.proximal_adagrad.l1 = ( 2099 self._optimization_parameters.l1_regularization_strength) 2100 table_descriptor.optimization_parameters.proximal_adagrad.l2 = ( 2101 self._optimization_parameters.l2_regularization_strength) 2102 2103 def get_default_slot_variable_names(self, table): 2104 return ProximalAdagradSlotVariableName('{}/{}'.format( 2105 table, 'ProximalAdagrad')) 2106 2107 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2108 table_config, table_variables, config_proto): 2109 accumulator_initializer = init_ops.constant_initializer( 2110 self._optimization_parameters.initial_accumulator) 2111 accumulator_variables = _create_partitioned_variables( 2112 name=slot_variable_names.accumulator, 2113 num_hosts=num_hosts, 2114 vocabulary_size=table_config.vocabulary_size, 2115 embedding_dimension=table_config.dimension, 2116 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2117 initializer=accumulator_initializer) 2118 slot_variables = ProximalAdagradSlotVariable(accumulator_variables) 2119 2120 def load_ops_fn(): 2121 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 2122 2123 Returns: 2124 A list of ops to load embedding and slot variables from CPU to TPU. 2125 """ 2126 config = config_proto 2127 load_op_list = [] 2128 for host_id, table_variable, accumulator_variable in zip( 2129 range(num_hosts), table_variables, accumulator_variables): 2130 with ops.colocate_with(table_variable): 2131 load_parameters_op = ( 2132 tpu_ops.load_tpu_embedding_proximal_adagrad_parameters( 2133 parameters=table_variable, 2134 accumulators=accumulator_variable, 2135 table_name=table, 2136 num_shards=num_hosts, 2137 shard_id=host_id, 2138 config=config)) 2139 config = None 2140 load_op_list.append(load_parameters_op) 2141 return load_op_list 2142 2143 def retrieve_ops_fn(): 2144 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 2145 2146 Returns: 2147 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
2148 """ 2149 config = config_proto 2150 retrieve_op_list = [] 2151 for host_id, table_variable, accumulator_variable in (zip( 2152 range(num_hosts), table_variables, accumulator_variables)): 2153 with ops.colocate_with(table_variable): 2154 retrieved_table, retrieved_accumulator = ( 2155 tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters( 2156 table_name=table, 2157 num_shards=num_hosts, 2158 shard_id=host_id, 2159 config=config)) 2160 retrieve_parameters_op = control_flow_ops.group( 2161 state_ops.assign(table_variable, retrieved_table), 2162 state_ops.assign(accumulator_variable, retrieved_accumulator)) 2163 config = None 2164 retrieve_op_list.append(retrieve_parameters_op) 2165 return retrieve_op_list 2166 2167 return slot_variables, load_ops_fn, retrieve_ops_fn 2168 2169 2170class _AdamHandler(_OptimizerHandler): 2171 """Handles Adam specific logic.""" 2172 2173 def set_optimization_parameters(self, table_descriptor): 2174 table_descriptor.optimization_parameters.adam.beta1 = ( 2175 self._optimization_parameters.beta1) 2176 table_descriptor.optimization_parameters.adam.beta2 = ( 2177 self._optimization_parameters.beta2) 2178 table_descriptor.optimization_parameters.adam.epsilon = ( 2179 self._optimization_parameters.epsilon) 2180 table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( 2181 not self._optimization_parameters.lazy_adam) 2182 table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( 2183 self._optimization_parameters.sum_inside_sqrt) 2184 2185 def get_default_slot_variable_names(self, table): 2186 return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), 2187 '{}/{}/v'.format(table, 'Adam')) 2188 2189 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2190 table_config, table_variables, config_proto): 2191 m_initializer = init_ops.zeros_initializer() 2192 m_variables = _create_partitioned_variables( 2193 name=slot_variable_names.m, 2194 num_hosts=num_hosts, 2195 vocabulary_size=table_config.vocabulary_size, 2196 embedding_dimension=table_config.dimension, 2197 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2198 initializer=m_initializer) 2199 v_initializer = init_ops.zeros_initializer() 2200 v_variables = _create_partitioned_variables( 2201 name=slot_variable_names.v, 2202 num_hosts=num_hosts, 2203 vocabulary_size=table_config.vocabulary_size, 2204 embedding_dimension=table_config.dimension, 2205 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2206 initializer=v_initializer) 2207 slot_variables = AdamSlotVariables(m_variables, v_variables) 2208 2209 def load_ops_fn(): 2210 """Returns the retrieve ops for AdaGrad embedding tables. 2211 2212 Returns: 2213 A list of ops to load embedding and slot variables from CPU to TPU. 2214 """ 2215 load_op_list = [] 2216 config = config_proto 2217 for host_id, table_variable, m_variable, v_variable in (zip( 2218 range(num_hosts), table_variables, m_variables, v_variables)): 2219 with ops.colocate_with(table_variable): 2220 load_parameters_op = ( 2221 tpu_ops.load_tpu_embedding_adam_parameters( 2222 parameters=table_variable, 2223 momenta=m_variable, 2224 velocities=v_variable, 2225 table_name=table, 2226 num_shards=num_hosts, 2227 shard_id=host_id, 2228 config=config)) 2229 # Set config to None to enforce that config is only loaded to the first 2230 # table. 2231 config = None 2232 load_op_list.append(load_parameters_op) 2233 return load_op_list 2234 2235 def retrieve_ops_fn(): 2236 """Returns the retrieve ops for Adam embedding tables. 
2237 2238 Returns: 2239 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2240 """ 2241 retrieve_op_list = [] 2242 config = config_proto 2243 for host_id, table_variable, m_variable, v_variable in (zip( 2244 range(num_hosts), table_variables, m_variables, v_variables)): 2245 with ops.colocate_with(table_variable): 2246 retrieved_table, retrieved_m, retrieved_v = ( 2247 tpu_ops.retrieve_tpu_embedding_adam_parameters( 2248 table_name=table, 2249 num_shards=num_hosts, 2250 shard_id=host_id, 2251 config=config)) 2252 retrieve_parameters_op = control_flow_ops.group( 2253 state_ops.assign(table_variable, retrieved_table), 2254 state_ops.assign(m_variable, retrieved_m), 2255 state_ops.assign(v_variable, retrieved_v)) 2256 config = None 2257 retrieve_op_list.append(retrieve_parameters_op) 2258 return retrieve_op_list 2259 2260 return slot_variables, load_ops_fn, retrieve_ops_fn 2261 2262 2263class _FtrlHandler(_OptimizerHandler): 2264 """Handles Ftrl specific logic.""" 2265 2266 def set_optimization_parameters(self, table_descriptor): 2267 table_descriptor.optimization_parameters.ftrl.lr_power = ( 2268 self._optimization_parameters.learning_rate_power) 2269 table_descriptor.optimization_parameters.ftrl.l1 = ( 2270 self._optimization_parameters.l1_regularization_strength) 2271 table_descriptor.optimization_parameters.ftrl.l2 = ( 2272 self._optimization_parameters.l2_regularization_strength) 2273 table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = ( 2274 self._optimization_parameters.multiply_linear_by_learning_rate) 2275 table_descriptor.optimization_parameters.ftrl.beta = ( 2276 self._optimization_parameters.beta) 2277 table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = ( 2278 self._optimization_parameters.allow_zero_accumulator) 2279 2280 def get_default_slot_variable_names(self, table): 2281 # These match the default slot variable names created by 2282 # tf.train.FtrlOptimizer. 2283 return FtrlSlotVariableName( 2284 '{}/{}'.format(table, 'Ftrl'), # accumulator 2285 '{}/{}'.format(table, 'Ftrl_1')) # linear 2286 2287 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2288 table_config, table_variables, config_proto): 2289 accumulator_initializer = init_ops.constant_initializer( 2290 self._optimization_parameters.initial_accumulator_value) 2291 accumulator_variables = _create_partitioned_variables( 2292 name=slot_variable_names.accumulator, 2293 num_hosts=num_hosts, 2294 vocabulary_size=table_config.vocabulary_size, 2295 embedding_dimension=table_config.dimension, 2296 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2297 initializer=accumulator_initializer) 2298 linear_initializer = init_ops.constant_initializer( 2299 self._optimization_parameters.initial_linear_value) 2300 linear_variables = _create_partitioned_variables( 2301 name=slot_variable_names.linear, 2302 num_hosts=num_hosts, 2303 vocabulary_size=table_config.vocabulary_size, 2304 embedding_dimension=table_config.dimension, 2305 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2306 initializer=linear_initializer) 2307 slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables) 2308 2309 def load_ops_fn(): 2310 """Returns the retrieve ops for Ftrl embedding tables. 2311 2312 Returns: 2313 A list of ops to load embedding and slot variables from CPU to TPU. 
2314 """ 2315 config = config_proto 2316 load_op_list = [] 2317 for host_id, table_variable, accumulator_variable, linear_variable in zip( 2318 range(num_hosts), table_variables, accumulator_variables, 2319 linear_variables): 2320 with ops.colocate_with(table_variable): 2321 load_parameters_op = ( 2322 tpu_ops.load_tpu_embedding_ftrl_parameters( 2323 parameters=table_variable, 2324 accumulators=accumulator_variable, 2325 linears=linear_variable, 2326 table_name=table, 2327 num_shards=num_hosts, 2328 shard_id=host_id, 2329 config=config)) 2330 config = None 2331 load_op_list.append(load_parameters_op) 2332 return load_op_list 2333 2334 def retrieve_ops_fn(): 2335 """Returns the retrieve ops for Ftrl embedding tables. 2336 2337 Returns: 2338 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2339 """ 2340 config = config_proto 2341 retrieve_op_list = [] 2342 for host_id, table_variable, accumulator_variable, linear_variable in zip( 2343 range(num_hosts), table_variables, accumulator_variables, 2344 linear_variables): 2345 with ops.colocate_with(table_variable): 2346 retrieved_table, retrieved_accumulator, retrieved_linear = ( 2347 tpu_ops.retrieve_tpu_embedding_ftrl_parameters( 2348 table_name=table, 2349 num_shards=num_hosts, 2350 shard_id=host_id, 2351 config=config)) 2352 retrieve_parameters_op = control_flow_ops.group( 2353 state_ops.assign(table_variable, retrieved_table), 2354 state_ops.assign(accumulator_variable, retrieved_accumulator), 2355 state_ops.assign(linear_variable, retrieved_linear)) 2356 config = None 2357 retrieve_op_list.append(retrieve_parameters_op) 2358 return retrieve_op_list 2359 2360 return slot_variables, load_ops_fn, retrieve_ops_fn 2361 2362 2363class _ProximalYogiHandler(_OptimizerHandler): 2364 """Handles Proximal Yogi specific logic.""" 2365 2366 def set_optimization_parameters(self, table_descriptor): 2367 table_descriptor.optimization_parameters.proximal_yogi.SetInParent() 2368 table_descriptor.optimization_parameters.proximal_yogi.beta1 = ( 2369 self._optimization_parameters.beta1) 2370 table_descriptor.optimization_parameters.proximal_yogi.beta2 = ( 2371 self._optimization_parameters.beta2) 2372 table_descriptor.optimization_parameters.proximal_yogi.epsilon = ( 2373 self._optimization_parameters.epsilon) 2374 table_descriptor.optimization_parameters.proximal_yogi.l1 = ( 2375 self._optimization_parameters.l1_regularization_strength) 2376 table_descriptor.optimization_parameters.proximal_yogi.l2 = ( 2377 self._optimization_parameters.l2_regularization_strength) 2378 2379 def get_default_slot_variable_names(self, table): 2380 return ProximalYogiSlotVariableNames( 2381 '{}/{}'.format(table, 'ProximalYogi'), # v 2382 '{}/{}_1'.format(table, 'ProximalYogi')) # m 2383 2384 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2385 table_config, table_variables, config_proto): 2386 v_initializer = init_ops.constant_initializer( 2387 self._optimization_parameters.initial_accumulator_value) 2388 v_variables = _create_partitioned_variables( 2389 name=slot_variable_names.v, 2390 num_hosts=num_hosts, 2391 vocabulary_size=table_config.vocabulary_size, 2392 embedding_dimension=table_config.dimension, 2393 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2394 initializer=v_initializer) 2395 m_initializer = init_ops.zeros_initializer() 2396 m_variables = _create_partitioned_variables( 2397 name=slot_variable_names.m, 2398 num_hosts=num_hosts, 2399 vocabulary_size=table_config.vocabulary_size, 2400 
embedding_dimension=table_config.dimension, 2401 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2402 initializer=m_initializer) 2403 slot_variables = ProximalYogiSlotVariables(v_variables, m_variables) 2404 2405 def load_ops_fn(): 2406 """Returns the load ops for Proximal Yogi embedding tables. 2407 2408 Returns: 2409 A list of ops to load embedding and slot variables from CPU to TPU. 2410 """ 2411 load_op_list = [] 2412 config = config_proto 2413 for host_id, table_variable, v_variable, m_variable in (zip( 2414 range(num_hosts), table_variables, v_variables, m_variables)): 2415 with ops.colocate_with(table_variable): 2416 load_parameters_op = ( 2417 tpu_ops.load_tpu_embedding_proximal_yogi_parameters( 2418 parameters=table_variable, 2419 v=v_variable, 2420 m=m_variable, 2421 table_name=table, 2422 num_shards=num_hosts, 2423 shard_id=host_id, 2424 config=config)) 2425 # Set config to None to enforce that config is only loaded to the first 2426 # table. 2427 config = None 2428 load_op_list.append(load_parameters_op) 2429 return load_op_list 2430 2431 def retrieve_ops_fn(): 2432 """Returns the retrieve ops for Proximal Yogi embedding tables. 2433 2434 Returns: 2435 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2436 """ 2437 retrieve_op_list = [] 2438 config = config_proto 2439 for host_id, table_variable, v_variable, m_variable in (zip( 2440 range(num_hosts), table_variables, v_variables, m_variables)): 2441 with ops.colocate_with(table_variable): 2442 retrieved_table, retrieved_v, retrieved_m = ( 2443 tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters( 2444 table_name=table, 2445 num_shards=num_hosts, 2446 shard_id=host_id, 2447 config=config)) 2448 retrieve_parameters_op = control_flow_ops.group( 2449 state_ops.assign(table_variable, retrieved_table), 2450 state_ops.assign(v_variable, retrieved_v), 2451 state_ops.assign(m_variable, retrieved_m)) 2452 config = None 2453 retrieve_op_list.append(retrieve_parameters_op) 2454 return retrieve_op_list 2455 2456 return slot_variables, load_ops_fn, retrieve_ops_fn 2457 2458 2459class _MomentumHandler(_OptimizerHandler): 2460 """Handles Momentum specific logic.""" 2461 2462 def set_optimization_parameters(self, table_descriptor): 2463 (table_descriptor.optimization_parameters.momentum.SetInParent()) 2464 table_descriptor.optimization_parameters.momentum.momentum = ( 2465 self._optimization_parameters.momentum) 2466 table_descriptor.optimization_parameters.momentum.use_nesterov = ( 2467 self._optimization_parameters.use_nesterov) 2468 2469 def get_default_slot_variable_names(self, table): 2470 return MomentumSlotVariableName('{}/{}'.format(table, 'Momentum')) 2471 2472 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2473 table_config, table_variables, config_proto): 2474 2475 momenta_initializer = init_ops.zeros_initializer() 2476 momenta_variables = _create_partitioned_variables( 2477 name=slot_variable_names.momenta, 2478 num_hosts=num_hosts, 2479 vocabulary_size=table_config.vocabulary_size, 2480 embedding_dimension=table_config.dimension, 2481 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2482 initializer=momenta_initializer) 2483 slot_variables = MomentumSlotVariable(momenta_variables) 2484 2485 def load_ops_fn(): 2486 """Returns the retrieve ops for Momentum embedding tables. 2487 2488 Returns: 2489 A list of ops to load embedding and slot variables from CPU to TPU. 
2490 """ 2491 load_op_list = [] 2492 config = config_proto 2493 for host_id, table_variable, momenta_variable in (zip( 2494 range(num_hosts), table_variables, momenta_variables)): 2495 with ops.colocate_with(table_variable): 2496 load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters( 2497 parameters=table_variable, 2498 momenta=momenta_variable, 2499 table_name=table, 2500 num_shards=num_hosts, 2501 shard_id=host_id, 2502 config=config, 2503 ) 2504 config = None 2505 load_op_list.append(load_parameters_op) 2506 return load_op_list 2507 2508 def retrieve_ops_fn(): 2509 """Returns the retrieve ops for Momentum embedding tables. 2510 2511 Returns: 2512 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2513 """ 2514 retrieve_op_list = [] 2515 config = config_proto 2516 for host_id, table_variable, momenta_variable in (zip( 2517 range(num_hosts), table_variables, momenta_variables)): 2518 with ops.colocate_with(table_variable): 2519 retrieved_table, retrieved_momenta = ( 2520 tpu_ops.retrieve_tpu_embedding_momentum_parameters( 2521 table_name=table, 2522 num_shards=num_hosts, 2523 shard_id=host_id, 2524 config=config, 2525 )) 2526 retrieve_parameters_op = control_flow_ops.group( 2527 state_ops.assign(table_variable, retrieved_table), 2528 state_ops.assign(momenta_variable, retrieved_momenta)) 2529 config = None 2530 retrieve_op_list.append(retrieve_parameters_op) 2531 return retrieve_op_list 2532 2533 return slot_variables, load_ops_fn, retrieve_ops_fn 2534 2535 2536class _RMSPropHandler(_OptimizerHandler): 2537 """Handles RMS prop specific logic.""" 2538 2539 def set_optimization_parameters(self, table_descriptor): 2540 (table_descriptor.optimization_parameters.rms_prop.SetInParent()) 2541 table_descriptor.optimization_parameters.rms_prop.rho = ( 2542 self._optimization_parameters.rho) 2543 table_descriptor.optimization_parameters.rms_prop.epsilon = ( 2544 self._optimization_parameters.epsilon) 2545 table_descriptor.optimization_parameters.rms_prop.momentum = ( 2546 self._optimization_parameters.momentum) 2547 2548 def get_default_slot_variable_names(self, table): 2549 return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'), 2550 '{}/{}/mom'.format(table, 'RMSProp')) 2551 2552 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2553 table_config, table_variables, config_proto): 2554 2555 ms_variables = _create_partitioned_variables( 2556 name=slot_variable_names.ms, 2557 num_hosts=num_hosts, 2558 vocabulary_size=table_config.vocabulary_size, 2559 embedding_dimension=table_config.dimension, 2560 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2561 initializer=init_ops.zeros_initializer(), 2562 ) 2563 mom_variables = _create_partitioned_variables( 2564 name=slot_variable_names.mom, 2565 num_hosts=num_hosts, 2566 vocabulary_size=table_config.vocabulary_size, 2567 embedding_dimension=table_config.dimension, 2568 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2569 initializer=init_ops.zeros_initializer(), 2570 ) 2571 slot_variables = RMSPropSlotVariables(ms_variables, mom_variables) 2572 2573 def load_ops_fn(): 2574 """Returns the retrieve ops for RMS Prop embedding tables. 2575 2576 Returns: 2577 A list of ops to load embedding and slot variables from CPU to TPU. 
2578 """ 2579 load_op_list = [] 2580 config = config_proto 2581 for host_id, table_variable, ms_variable, mom_variable in (zip( 2582 range(num_hosts), table_variables, ms_variables, mom_variables)): 2583 with ops.colocate_with(table_variable): 2584 load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters( 2585 parameters=table_variable, 2586 ms=ms_variable, 2587 mom=mom_variable, 2588 table_name=table, 2589 num_shards=num_hosts, 2590 shard_id=host_id, 2591 config=config, 2592 ) 2593 config = None 2594 load_op_list.append(load_parameters_op) 2595 return load_op_list 2596 2597 def retrieve_ops_fn(): 2598 """Returns the retrieve ops for RMS Prop embedding tables. 2599 2600 Returns: 2601 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2602 """ 2603 retrieve_op_list = [] 2604 config = config_proto 2605 for host_id, table_variable, ms_variable, mom_variable in (zip( 2606 range(num_hosts), table_variables, ms_variables, mom_variables)): 2607 with ops.colocate_with(table_variable): 2608 retrieved_table, retrieved_ms, retrieved_mom = ( 2609 tpu_ops.retrieve_tpu_embedding_rms_prop_parameters( 2610 table_name=table, 2611 num_shards=num_hosts, 2612 shard_id=host_id, 2613 config=config, 2614 )) 2615 retrieve_parameters_op = control_flow_ops.group( 2616 state_ops.assign(table_variable, retrieved_table), 2617 state_ops.assign(ms_variable, retrieved_ms), 2618 state_ops.assign(mom_variable, retrieved_mom)) 2619 config = None 2620 retrieve_op_list.append(retrieve_parameters_op) 2621 return retrieve_op_list 2622 2623 return slot_variables, load_ops_fn, retrieve_ops_fn 2624 2625 2626class _FrequencyEstimatorHandler(_OptimizerHandler): 2627 """Handles frequency estimator specific logic.""" 2628 2629 def set_optimization_parameters(self, table_descriptor): 2630 table_descriptor.optimization_parameters.frequency_estimator.SetInParent() 2631 freq = table_descriptor.optimization_parameters.frequency_estimator 2632 freq.tau = self._optimization_parameters.tau 2633 freq.max_delta = self._optimization_parameters.max_delta 2634 freq.outlier_threshold = self._optimization_parameters.outlier_threshold 2635 freq.weight_exponent = self._optimization_parameters.weight_exponent 2636 2637 def get_default_slot_variable_names(self, table): 2638 return FrequencyEstimatorSlotVariableName( 2639 '{}/FrequencyEstimator'.format(table)) 2640 2641 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2642 table_config, table_variables, config_proto): 2643 if table_config.dimension != 1: 2644 raise ValueError('FrequencyEstimator tables should only have a dimension ' 2645 'of 1. Received dimension {}'.format( 2646 table_config.dimension)) 2647 2648 last_hit_step_variables = _create_partitioned_variables( 2649 name=slot_variable_names.last_hit_step, 2650 num_hosts=num_hosts, 2651 vocabulary_size=table_config.vocabulary_size, 2652 embedding_dimension=table_config.dimension, 2653 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2654 initializer=init_ops.zeros_initializer(), 2655 ) 2656 slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables) 2657 2658 def load_ops_fn(): 2659 """Returns the retrieve ops for Frequency Estimator embedding tables. 2660 2661 Returns: 2662 A list of ops to load embedding and slot variables from CPU to TPU. 
2663 """ 2664 load_op_list = [] 2665 config = config_proto 2666 for host_id, table_variable, last_hit_step_variable in (zip( 2667 range(num_hosts), table_variables, last_hit_step_variables)): 2668 with ops.colocate_with(table_variable): 2669 load_parameters_op = ( 2670 tpu_ops.load_tpu_embedding_frequency_estimator_parameters( 2671 parameters=table_variable, 2672 last_hit_step=last_hit_step_variable, 2673 table_name=table, 2674 num_shards=num_hosts, 2675 shard_id=host_id, 2676 config=config)) 2677 config = None 2678 load_op_list.append(load_parameters_op) 2679 return load_op_list 2680 2681 def retrieve_ops_fn(): 2682 """Returns the retrieve ops for Frequency Estimator embedding tables. 2683 2684 Returns: 2685 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2686 """ 2687 retrieve_op_list = [] 2688 config = config_proto 2689 for host_id, table_variable, last_hit_step_variable in (zip( 2690 range(num_hosts), table_variables, last_hit_step_variables)): 2691 with ops.colocate_with(table_variable): 2692 retrieved_table, retrieved_last_hit_step = ( 2693 tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters( 2694 table_name=table, 2695 num_shards=num_hosts, 2696 shard_id=host_id, 2697 config=config, 2698 )) 2699 retrieve_parameters_op = control_flow_ops.group( 2700 state_ops.assign(table_variable, retrieved_table), 2701 state_ops.assign(last_hit_step_variable, retrieved_last_hit_step)) 2702 config = None 2703 retrieve_op_list.append(retrieve_parameters_op) 2704 return retrieve_op_list 2705 2706 return slot_variables, load_ops_fn, retrieve_ops_fn 2707 2708 2709class _StochasticGradientDescentHandler(_OptimizerHandler): 2710 """Handles stochastic gradient descent specific logic.""" 2711 2712 def set_optimization_parameters(self, table_descriptor): 2713 (table_descriptor.optimization_parameters.stochastic_gradient_descent 2714 .SetInParent()) 2715 2716 def get_default_slot_variable_names(self, table): 2717 return None 2718 2719 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2720 table_config, table_variables, config_proto): 2721 del table_config 2722 2723 def load_ops_fn(): 2724 """Returns the retrieve ops for AdaGrad embedding tables. 2725 2726 Returns: 2727 A list of ops to load embedding and slot variables from CPU to TPU. 2728 """ 2729 load_op_list = [] 2730 config = config_proto 2731 for host_id, table_variable in enumerate(table_variables): 2732 with ops.colocate_with(table_variable): 2733 load_parameters_op = ( 2734 tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters( 2735 parameters=table_variable, 2736 table_name=table, 2737 num_shards=num_hosts, 2738 shard_id=host_id, 2739 config=config)) 2740 config = None 2741 load_op_list.append(load_parameters_op) 2742 return load_op_list 2743 2744 def retrieve_ops_fn(): 2745 """Returns the retrieve ops for SGD embedding tables. 2746 2747 Returns: 2748 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
2749 """ 2750 retrieve_op_list = [] 2751 config = config_proto 2752 for host_id, table_variable in enumerate(table_variables): 2753 with ops.colocate_with(table_variable): 2754 retrieved_table = ( 2755 tpu_ops 2756 .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( 2757 table_name=table, 2758 num_shards=num_hosts, 2759 shard_id=host_id, 2760 config=config)) 2761 retrieve_parameters_op = control_flow_ops.group( 2762 state_ops.assign(table_variable, retrieved_table)) 2763 config = None 2764 retrieve_op_list.append(retrieve_parameters_op) 2765 return retrieve_op_list 2766 2767 return None, load_ops_fn, retrieve_ops_fn 2768 2769 2770def _get_optimization_handler(optimization_parameters): 2771 """Gets the optimization handler given the parameter type.""" 2772 if isinstance(optimization_parameters, AdagradParameters): 2773 return _AdagradHandler(optimization_parameters) 2774 elif isinstance(optimization_parameters, ProximalAdagradParameters): 2775 return _ProximalAdagradHandler(optimization_parameters) 2776 elif isinstance(optimization_parameters, AdamParameters): 2777 return _AdamHandler(optimization_parameters) 2778 elif isinstance(optimization_parameters, FtrlParameters): 2779 return _FtrlHandler(optimization_parameters) 2780 elif isinstance(optimization_parameters, ProximalYogiParameters): 2781 return _ProximalYogiHandler(optimization_parameters) 2782 elif isinstance(optimization_parameters, StochasticGradientDescentParameters): 2783 return _StochasticGradientDescentHandler(optimization_parameters) 2784 elif isinstance(optimization_parameters, MomentumParameters): 2785 return _MomentumHandler(optimization_parameters) 2786 elif isinstance(optimization_parameters, RMSPropParameters): 2787 return _RMSPropHandler(optimization_parameters) 2788 elif isinstance(optimization_parameters, FrequencyEstimatorParameters): 2789 return _FrequencyEstimatorHandler(optimization_parameters) 2790 return NotImplementedError() 2791 2792 2793def _create_ordered_dict(d): 2794 """Create an OrderedDict from Dict.""" 2795 return collections.OrderedDict((k, d[k]) for k in sorted(d)) 2796 2797 2798def _create_combiners(table_to_config_dict, table_to_features_dict): 2799 """Create a per feature list of combiners, ordered by table.""" 2800 combiners = [] 2801 for table in table_to_config_dict: 2802 combiner = table_to_config_dict[table].combiner or 'sum' 2803 combiners.extend([combiner] * len(table_to_features_dict[table])) 2804 return combiners 2805 2806 2807def _create_table_to_features_and_num_features_dicts(feature_to_config_dict): 2808 """Create mapping from table to a list of its features.""" 2809 table_to_features_dict_tmp = {} 2810 table_to_num_features_dict_tmp = {} 2811 for feature, feature_config in six.iteritems(feature_to_config_dict): 2812 if feature_config.table_id in table_to_features_dict_tmp: 2813 table_to_features_dict_tmp[feature_config.table_id].append(feature) 2814 else: 2815 table_to_features_dict_tmp[feature_config.table_id] = [feature] 2816 table_to_num_features_dict_tmp[feature_config.table_id] = 0 2817 if feature_config.max_sequence_length == 0: 2818 table_to_num_features_dict_tmp[feature_config.table_id] = ( 2819 table_to_num_features_dict_tmp[feature_config.table_id] + 1) 2820 else: 2821 table_to_num_features_dict_tmp[feature_config.table_id] = ( 2822 table_to_num_features_dict_tmp[feature_config.table_id] + 2823 feature_config.max_sequence_length) 2824 2825 table_to_features_dict = collections.OrderedDict() 2826 table_to_num_features_dict = collections.OrderedDict() 2827 for 
table in sorted(table_to_features_dict_tmp): 2828 table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) 2829 table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table] 2830 return table_to_features_dict, table_to_num_features_dict 2831 2832 2833def _create_device_fn(hosts): 2834 """Create device_fn() to use with _create_partitioned_variables().""" 2835 2836 def device_fn(op): 2837 """Returns the `device` for `op`.""" 2838 part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) 2839 dummy_match = re.match(r'.*dummy_(\d+).*', op.name) 2840 if not part_match and not dummy_match: 2841 raise RuntimeError( 2842 'Internal Error: Expected {} to contain /part_* or dummy_*'.format( 2843 op.name)) 2844 2845 if part_match: 2846 idx = int(part_match.group(1)) 2847 else: 2848 idx = int(dummy_match.group(1)) # pytype: disable=attribute-error 2849 2850 device = hosts[idx] 2851 logging.debug('assigning {} to {}.', op, device) 2852 return device 2853 2854 return device_fn 2855 2856 2857def _create_partitioned_variables(name, 2858 num_hosts, 2859 vocabulary_size, 2860 embedding_dimension, 2861 initializer, 2862 collections=None): # pylint: disable=redefined-outer-name 2863 """Creates PartitionedVariables based on `num_hosts` for `table`.""" 2864 2865 num_slices = min(vocabulary_size, num_hosts) 2866 2867 var_list = list( 2868 variable_scope.get_variable( 2869 name, 2870 shape=(vocabulary_size, embedding_dimension), 2871 partitioner=partitioned_variables.fixed_size_partitioner(num_slices), 2872 dtype=dtypes.float32, 2873 initializer=initializer, 2874 collections=collections, 2875 trainable=False)) 2876 2877 if vocabulary_size >= num_hosts: 2878 return var_list 2879 2880 # For padded part, define the dummy variable to be loaded into TPU system. 2881 for idx in range(num_hosts - vocabulary_size): 2882 var_list.append( 2883 variable_scope.get_variable( 2884 'dummy_{}_{}'.format(vocabulary_size + idx, name), 2885 shape=(1, embedding_dimension), 2886 dtype=dtypes.float32, 2887 initializer=initializer, 2888 collections=[ops.GraphKeys.LOCAL_VARIABLES], 2889 trainable=False)) 2890 2891 return var_list 2892
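
# A worked example of the padding behaviour above (illustrative numbers): with
# vocabulary_size=3 and num_hosts=5, _create_partitioned_variables() returns 3
# real partitions plus 2 local 'dummy_*' variables of shape
# (1, embedding_dimension), so every host still has a shard to load into the
# TPU system; _create_device_fn() then places e.g. 'table/part_1/...' on
# hosts[1] and 'dummy_3_table' on hosts[3].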