1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU embedding APIs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import math
24import re
25
26from typing import Optional
27
28import six
29
30from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
31from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
32from tensorflow.python.eager import context
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import init_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import partitioned_variables
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
44from tensorflow.python.tpu.ops import tpu_ops
45from tensorflow.python.util.tf_export import tf_export
46
47TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
48INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
49
50
51# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
52#  as AdagradParameters etc instead of learning_rate.
53class TableConfig(
54    collections.namedtuple('TableConfig', [
55        'vocabulary_size',
56        'dimension',
57        'initializer',
58        'combiner',
59        'hot_id_replication',
60        'learning_rate',
61        'learning_rate_fn',
62        'optimization_parameters',
63    ])):
64  """Embedding table configuration."""
65
66  def __new__(cls,
67              vocabulary_size,
68              dimension,
69              initializer=None,
70              combiner='mean',
71              hot_id_replication=False,
72              learning_rate=None,
73              learning_rate_fn=None,
74              optimization_parameters=None):
75    """Embedding table configuration.
76
77    Args:
78      vocabulary_size: Number of vocabulary (/rows) in the table.
79      dimension: The embedding dimension.
80      initializer: A variable initializer function to be used in embedding
81        variable initialization. If not specified, defaults to
82        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
83        deviation `1/sqrt(dimension)`.
84      combiner: A string specifying how to reduce if there are multiple entries
85        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
86        supported, with 'mean' the default. 'sqrtn' often achieves good
87        accuracy, in particular with bag-of-words columns. For more information,
88        see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
89        than sparse tensors.
90      hot_id_replication: If true, enables hot id replication, which can make
91        embedding lookups faster if there are some hot rows in the table.
92      learning_rate: float, static learning rate for this table. If
93        learning_rate and learning_rate_fn are both `None`, static learning rate
94        as specified in local `optimization_parameters` will be used. In case
95        local `optimization_parameters` is `None`, global
96        `optimization_parameters` in `TPUEmbedding` constructor will be used.
97        `learning_rate_fn` must be `None` if `learning_rate` is not `None.
98      learning_rate_fn: string, use dynamic learning rate given by the function.
99        This function will be passed the current global step. If learning_rate
100        and learning_rate_fn are both `None`, static learning rate as specified
101        in `optimization_parameters` is used. `learning_rate` must be `None` if
102        `learning_rate_fn` is not `None.
103      optimization_parameters: `AdagradParameters`, `AdamParameters`,
104        `Stochasticgradientdescentparameters`. Specifies table level optimizer.
105        If it's `None` global optimizer in `TPUEmbedding` constructor is used.
106
107    Returns:
108      `TableConfig`.
109
110    Raises:
111      ValueError: if `vocabulary_size` is not positive integer.
112      ValueError: if `dimension` is not positive integer.
113      ValueError: if `initializer` is specified and is not callable.
114      ValueError: if `combiner` is not supported.
115      ValueError: if `learning_rate` and `learning_rate_fn` are both not
116        `None`.
117    """
118    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
119      raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size))
120
121    if not isinstance(dimension, int) or dimension < 1:
122      raise ValueError('Invalid dimension {}.'.format(dimension))
123
124    if (initializer is not None) and (not callable(initializer)):
125      raise ValueError('initializer must be callable if specified.')
126    if initializer is None:
127      initializer = init_ops.truncated_normal_initializer(
128          mean=0.0, stddev=1 / math.sqrt(dimension))
129
130    if combiner not in ('mean', 'sum', 'sqrtn', None):
131      raise ValueError('Invalid combiner {}'.format(combiner))
132
133    if learning_rate is not None and learning_rate_fn is not None:
134      raise ValueError('At most one of learning_rate and learning_rate_fn '
135                       'can be None; got {} and {}'.format(
136                           learning_rate, learning_rate_fn))
137
138    if optimization_parameters is not None:
139      if not isinstance(optimization_parameters, _OptimizationParameters):
140        raise ValueError('`optimization_parameters` must inherit from '
141                         '`_OptimizationParameters`. '
142                         '`type(optimization_parameters)`={}'.format(
143                             type(optimization_parameters)))
144
145    return super(TableConfig,
146                 cls).__new__(cls, vocabulary_size, dimension, initializer,
147                              combiner, hot_id_replication, learning_rate,
148                              learning_rate_fn, optimization_parameters)
149
150
151class FeatureConfig(
152    collections.namedtuple('FeatureConfig',
153                           ['table_id', 'max_sequence_length', 'weight_key'])):
154  """Feature configuration."""
155
156  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
157    """Feature configuration.
158
159    Args:
160      table_id: Which table the feature is uses for embedding lookups.
161      max_sequence_length: If positive, the feature is a sequence feature with
162        the corresponding maximum sequence length. If the sequence is longer
163        than this, it will be truncated. If 0, the feature is not a sequence
164        feature.
165      weight_key: If using weights for the combiner, this key specifies which
166        input feature contains the weights.
167
168    Returns:
169      `FeatureConfig`.
170
171    Raises:
172      ValueError: if `max_sequence_length` non-negative.
173    """
174    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
175      raise ValueError(
176          'Invalid max_sequence_length {}.'.format(max_sequence_length))
177
178    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
179                                             weight_key)
180
181
182class EnqueueData(
183    collections.namedtuple(
184        'EnqueueData',
185        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
186  """Data to be enqueued through generate_enqueue_ops()."""
187
188  def __new__(cls,
189              embedding_indices,
190              sample_indices=None,
191              aggregation_weights=None):
192    """Data to be enqueued through generate_enqueue_ops().
193
194    Args:
195      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
196        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
197        and int64 are allowed and will be converted to int32 internally.
198      sample_indices: A rank 2 Tensor specifying the training example to which
199        the corresponding embedding_indices and aggregation_weights values
200        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
201        If it is None, we assume each embedding_indices belongs to a different
202        sample. Both int32 and int64 are allowed and will be converted to int32
203        internally.
204      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
205        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
206        None, we assume all weights are 1. Both float32 and float64 are allowed
207        and will be converted to float32 internally.
208
209    Returns:
210      An EnqueueData tuple.
211
212    """
213    return super(EnqueueData, cls).__new__(cls, embedding_indices,
214                                           sample_indices, aggregation_weights)
215
216  @staticmethod
217  def from_sparse_tensor(sp_tensor, weights=None):
218    return EnqueueData(
219        sp_tensor.values,
220        sp_tensor.indices,
221        aggregation_weights=weights.values if weights is not None else None)
222
223
224class RaggedEnqueueData(
225    collections.namedtuple(
226        'RaggedEnqueueData',
227        ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
228  """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
229
230  def __new__(cls,
231              embedding_indices,
232              sample_splits=None,
233              aggregation_weights=None):
234    """Data to be enqueued through generate_enqueue_ops().
235
236    Args:
237      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
238        corresponds to ids.values in embedding_lookup(), when ids is a
239        RaggedTensor. Both int32 and int64 are allowed and will be converted to
240        int32 internally.
241      sample_splits: A rank 1 Tensor specifying the break points for splitting
242        embedding_indices and aggregation_weights into rows. It corresponds to
243        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
244        int32 and int64 are allowed and will be converted to int32 internally.
245      aggregation_weights: A rank 1 Tensor containing per training example
246        aggregation weights. It corresponds to the values field of a
247        RaggedTensor with the same row_splits as ids in embedding_lookup(), when
248        ids is a RaggedTensor.
249
250    Returns:
251      An RaggedEnqueueData tuple.
252
253    """
254    return super(RaggedEnqueueData,
255                 cls).__new__(cls, embedding_indices, sample_splits,
256                              aggregation_weights)
257
258  @staticmethod
259  def from_ragged_tensor(rg_tensor, weights=None):
260    return RaggedEnqueueData(
261        rg_tensor.values,
262        rg_tensor.row_splits,
263        aggregation_weights=weights.values if weights is not None else None)
264
265
266def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
267  """Convenient function for generate_enqueue_ops().
268
269  Args:
270    sp_tensors_list: a list of dictionary mapping from string of feature names
271      to SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
272      same host should be contiguous on the list.
273
274  Returns:
275    enqueue_datas_list: a list of dictionary mapping from string
276      of feature names to EnqueueData. Each dictionary is for one
277      TPU core. Dictionaries for the same host should be contiguous
278      on the list.
279
280  """
281  enqueue_datas_list = []
282  for sp_tensors in sp_tensors_list:
283    enqueue_datas = collections.OrderedDict(
284        (k, EnqueueData.from_sparse_tensor(v))
285        for k, v in six.iteritems(sp_tensors))
286    enqueue_datas_list.append(enqueue_datas)
287  return enqueue_datas_list
288
289
290def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
291  """Convenient function for generate_enqueue_ops().
292
293  Args:
294    rg_tensors_list: a list of dictionary mapping from string of feature names
295      to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
296      same host should be contiguous on the list.
297
298  Returns:
299    enqueue_datas_list: a list of dictionary mapping from string
300      of feature names to RaggedEnqueueData. Each dictionary is for one
301      TPU core. Dictionaries for the same host should be contiguous
302      on the list.
303
304  """
305  enqueue_datas_list = []
306  for rg_tensors in rg_tensors_list:
307    enqueue_datas = collections.OrderedDict(
308        (k, RaggedEnqueueData.from_ragged_tensor(v))
309        for k, v in six.iteritems(rg_tensors))
310    enqueue_datas_list.append(enqueue_datas)
311  return enqueue_datas_list
312
313
314AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
315                                               ['m', 'v'])
316
317AdagradSlotVariableName = collections.namedtuple('AdagradSlotVariableName',
318                                                 ['accumulator'])
319
320MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName',
321                                                  ['momenta'])
322
323RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
324                                                  ['ms', 'mom'])
325
326ProximalAdagradSlotVariableName = collections.namedtuple(
327    'ProximalAdagradSlotVariableName', ['accumulator'])
328
329FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
330                                              ['accumulator', 'linear'])
331
332ProximalYogiSlotVariableNames = collections.namedtuple(
333    'ProximalYogiSlotVariableNames', ['v', 'm'])
334
335FrequencyEstimatorSlotVariableName = collections.namedtuple(
336    'FrequencyEstimatorSlotVariableName', ['last_hit_step'])
337
338AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])
339
340MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
341                                              ['momenta'])
342
343RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
344                                              ['ms', 'mom'])
345
346AdagradSlotVariable = collections.namedtuple('AdagradSlotVariable',
347                                             ['accumulator'])
348
349ProximalAdagradSlotVariable = collections.namedtuple(
350    'ProximalAdagradSlotVariable', ['accumulator'])
351
352FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
353                                          ['accumulator', 'linear'])
354
355ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
356                                                   ['v', 'm'])
357
358FrequencyEstimatorSlotVariables = collections.namedtuple(
359    'FrequencyEstimatorSlotVariables', ['last_hit_step'])
360
361VariablesAndOps = collections.namedtuple('VariablesAndOps', [
362    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
363    'retrieve_ops'
364])
365
366
367class _OptimizationParameters(object):
368  """Parameters common to all optimizations."""
369
370  def __init__(
371      self,
372      learning_rate: float,
373      use_gradient_accumulation: bool,
374      clip_weight_min: Optional[float],
375      clip_weight_max: Optional[float],
376      weight_decay_factor: Optional[float],
377      multiply_weight_decay_factor_by_learning_rate: Optional[bool],
378      clip_gradient_min: Optional[float] = None,
379      clip_gradient_max: Optional[float] = None,
380  ):
381    self.learning_rate = learning_rate
382    self.use_gradient_accumulation = use_gradient_accumulation
383    self.clip_weight_min = clip_weight_min
384    self.clip_weight_max = clip_weight_max
385    self.weight_decay_factor = weight_decay_factor
386    self.multiply_weight_decay_factor_by_learning_rate = (
387        multiply_weight_decay_factor_by_learning_rate)
388    self.clip_gradient_min = clip_gradient_min
389    self.clip_gradient_max = clip_gradient_max
390
391    if not use_gradient_accumulation and (clip_gradient_min is not None or
392                                          clip_gradient_max is not None):
393      ValueError('When using gradient clipping limits, gradient accumulation '
394                 'must be enabled.')
395
396
397@tf_export(v1=['tpu.experimental.AdagradParameters'])
398class AdagradParameters(_OptimizationParameters):
399  """Optimization parameters for Adagrad with TPU embeddings.
400
401  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
402  `optimization_parameters` argument to set the optimizer and its parameters.
403  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
404  for more details.
405
406  ```
407  estimator = tf.estimator.tpu.TPUEstimator(
408      ...
409      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
410          ...
411          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
412          ...))
413  ```
414
415  """
416
417  def __init__(
418      self,
419      learning_rate: float,
420      initial_accumulator: float = 0.1,
421      use_gradient_accumulation: bool = True,
422      clip_weight_min: Optional[float] = None,
423      clip_weight_max: Optional[float] = None,
424      weight_decay_factor: Optional[float] = None,
425      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
426      clip_gradient_min: Optional[float] = None,
427      clip_gradient_max: Optional[float] = None,
428  ):
429    """Optimization parameters for Adagrad.
430
431    Args:
432      learning_rate: used for updating embedding table.
433      initial_accumulator: initial accumulator for Adagrad.
434      use_gradient_accumulation: setting this to `False` makes embedding
435        gradients calculation less accurate but faster. Please see
436        `optimization_parameters.proto` for details.
437      clip_weight_min: the minimum value to clip by; None means -infinity.
438      clip_weight_max: the maximum value to clip by; None means +infinity.
439      weight_decay_factor: amount of weight decay to apply; None means that the
440        weights are not decayed.
441      multiply_weight_decay_factor_by_learning_rate: if true,
442        `weight_decay_factor` is multiplied by the current learning rate.
443      clip_gradient_min: the minimum value to clip by; None means -infinity.
444        Gradient accumulation must be set to true if this is set.
445      clip_gradient_max: the maximum value to clip by; None means +infinity.
446        Gradient accumulation must be set to true if this is set.
447    """
448    super(AdagradParameters, self).__init__(
449        learning_rate=learning_rate,
450        use_gradient_accumulation=use_gradient_accumulation,
451        clip_weight_min=clip_weight_min,
452        clip_weight_max=clip_weight_max,
453        weight_decay_factor=weight_decay_factor,
454        multiply_weight_decay_factor_by_learning_rate=(
455            multiply_weight_decay_factor_by_learning_rate),
456        clip_gradient_min=clip_gradient_min,
457        clip_gradient_max=clip_gradient_max,
458    )
459    if initial_accumulator <= 0:
460      raise ValueError('Adagrad initial_accumulator must be positive')
461    self.initial_accumulator = initial_accumulator
462
463
464class ProximalAdagradParameters(_OptimizationParameters):
465  """Optimization parameters for ProximalAdagrad with TPU embeddings.
466
467  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
468  `optimization_parameters` argument to set the optimizer and its parameters.
469  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
470  for more details.
471  """
472
473  def __init__(
474      self,
475      learning_rate: float,
476      initial_accumulator: float = 0.1,
477      l1_regularization_strength: float = 0.0,
478      l2_regularization_strength: float = 0.0,
479      use_gradient_accumulation: bool = True,
480      clip_weight_min: Optional[float] = None,
481      clip_weight_max: Optional[float] = None,
482      weight_decay_factor: Optional[float] = None,
483      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
484      clip_gradient_min: Optional[float] = None,
485      clip_gradient_max: Optional[float] = None,
486  ):
487    """Optimization parameters for Adagrad.
488
489    Args:
490      learning_rate: used for updating embedding table.
491      initial_accumulator: initial accumulator for Adagrad.
492      l1_regularization_strength: A float value, must be greater than or equal
493        to zero.
494      l2_regularization_strength: A float value, must be greater than or equal
495        to zero.
496      use_gradient_accumulation: setting this to `False` makes embedding
497        gradients calculation less accurate but faster. Please see
498        `optimization_parameters.proto` for details. for details.
499      clip_weight_min: the minimum value to clip by; None means -infinity.
500      clip_weight_max: the maximum value to clip by; None means +infinity.
501      weight_decay_factor: amount of weight decay to apply; None means that the
502        weights are not decayed.
503      multiply_weight_decay_factor_by_learning_rate: if true,
504        `weight_decay_factor` is multiplied by the current learning rate.
505      clip_gradient_min: the minimum value to clip by; None means -infinity.
506        Gradient accumulation must be set to true if this is set.
507      clip_gradient_max: the maximum value to clip by; None means +infinity.
508        Gradient accumulation must be set to true if this is set.
509    """
510    super(ProximalAdagradParameters, self).__init__(
511        learning_rate=learning_rate,
512        use_gradient_accumulation=use_gradient_accumulation,
513        clip_weight_min=clip_weight_min,
514        clip_weight_max=clip_weight_max,
515        weight_decay_factor=weight_decay_factor,
516        multiply_weight_decay_factor_by_learning_rate=(
517            multiply_weight_decay_factor_by_learning_rate),
518        clip_gradient_min=clip_gradient_min,
519        clip_gradient_max=clip_gradient_max,
520    )
521    if initial_accumulator <= 0:
522      raise ValueError('Adagrad initial_accumulator must be positive')
523    if l1_regularization_strength < 0.:
524      raise ValueError('l1_regularization_strength must be greater than or '
525                       'equal to 0. got {}.'.format(l1_regularization_strength))
526
527    if l2_regularization_strength < 0.:
528      raise ValueError('l2_regularization_strength must be greater than or '
529                       'equal to 0. got {}.'.format(l2_regularization_strength))
530
531    self.initial_accumulator = initial_accumulator
532    self.l1_regularization_strength = l1_regularization_strength
533    self.l2_regularization_strength = l2_regularization_strength
534
535
536@tf_export(v1=['tpu.experimental.AdamParameters'])
537class AdamParameters(_OptimizationParameters):
538  """Optimization parameters for Adam with TPU embeddings.
539
540  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
541  `optimization_parameters` argument to set the optimizer and its parameters.
542  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
543  for more details.
544
545  ```
546  estimator = tf.estimator.tpu.TPUEstimator(
547      ...
548      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
549          ...
550          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
551          ...))
552  ```
553
554  """
555
556  def __init__(
557      self,
558      learning_rate: float,
559      beta1: float = 0.9,
560      beta2: float = 0.999,
561      epsilon: float = 1e-08,
562      lazy_adam: bool = True,
563      sum_inside_sqrt: bool = True,
564      use_gradient_accumulation: bool = True,
565      clip_weight_min: Optional[float] = None,
566      clip_weight_max: Optional[float] = None,
567      weight_decay_factor: Optional[float] = None,
568      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
569      clip_gradient_min: Optional[float] = None,
570      clip_gradient_max: Optional[float] = None,
571  ):
572    """Optimization parameters for Adam.
573
574    Args:
575      learning_rate: a floating point value. The learning rate.
576      beta1: A float value. The exponential decay rate for the 1st moment
577        estimates.
578      beta2: A float value. The exponential decay rate for the 2nd moment
579        estimates.
580      epsilon: A small constant for numerical stability.
581      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
582        `optimization_parameters.proto` for details.
583      sum_inside_sqrt: This improves training speed. Please see
584        `optimization_parameters.proto` for details.
585      use_gradient_accumulation: setting this to `False` makes embedding
586        gradients calculation less accurate but faster. Please see
587        `optimization_parameters.proto` for details.
588      clip_weight_min: the minimum value to clip by; None means -infinity.
589      clip_weight_max: the maximum value to clip by; None means +infinity.
590      weight_decay_factor: amount of weight decay to apply; None means that the
591        weights are not decayed.
592      multiply_weight_decay_factor_by_learning_rate: if true,
593        `weight_decay_factor` is multiplied by the current learning rate.
594      clip_gradient_min: the minimum value to clip by; None means -infinity.
595        Gradient accumulation must be set to true if this is set.
596      clip_gradient_max: the maximum value to clip by; None means +infinity.
597        Gradient accumulation must be set to true if this is set.
598    """
599    super(AdamParameters, self).__init__(
600        learning_rate=learning_rate,
601        use_gradient_accumulation=use_gradient_accumulation,
602        clip_weight_min=clip_weight_min,
603        clip_weight_max=clip_weight_max,
604        weight_decay_factor=weight_decay_factor,
605        multiply_weight_decay_factor_by_learning_rate=(
606            multiply_weight_decay_factor_by_learning_rate),
607        clip_gradient_min=clip_gradient_min,
608        clip_gradient_max=clip_gradient_max,
609    )
610    if beta1 < 0. or beta1 >= 1.:
611      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
612    if beta2 < 0. or beta2 >= 1.:
613      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
614    if epsilon <= 0.:
615      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
616    if not use_gradient_accumulation and not lazy_adam:
617      raise ValueError(
618          'When disabling Lazy Adam, gradient accumulation must be used.')
619
620    self.beta1 = beta1
621    self.beta2 = beta2
622    self.epsilon = epsilon
623    self.lazy_adam = lazy_adam
624    self.sum_inside_sqrt = sum_inside_sqrt
625
626
627@tf_export(v1=['tpu.experimental.FtrlParameters'])
628class FtrlParameters(_OptimizationParameters):
629  """Optimization parameters for Ftrl with TPU embeddings.
630
631  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
632  `optimization_parameters` argument to set the optimizer and its parameters.
633  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
634  for more details.
635
636  ```
637  estimator = tf.estimator.tpu.TPUEstimator(
638      ...
639      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
640          ...
641          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
642          ...))
643  ```
644
645  """
646
647  def __init__(
648      self,
649      learning_rate: float,
650      learning_rate_power: float = -0.5,
651      initial_accumulator_value: float = 0.1,
652      l1_regularization_strength: float = 0.0,
653      l2_regularization_strength: float = 0.0,
654      use_gradient_accumulation: bool = True,
655      clip_weight_min: Optional[float] = None,
656      clip_weight_max: Optional[float] = None,
657      weight_decay_factor: Optional[float] = None,
658      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
659      multiply_linear_by_learning_rate: bool = False,
660      beta: float = 0,
661      allow_zero_accumulator: bool = False,
662      clip_gradient_min: Optional[float] = None,
663      clip_gradient_max: Optional[float] = None,
664  ):
665    """Optimization parameters for Ftrl.
666
667    Implements FTRL as described in the following [paper](
668    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
669
670    Args:
671      learning_rate: a floating point value. The learning rate.
672      learning_rate_power: A float value, must be less or equal to zero.
673        Controls how the learning rate decreases during training. Use zero for a
674        fixed learning rate. See section 3.1 in the
675        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
676      initial_accumulator_value: The starting value for accumulators. Only zero
677        or positive values are allowed.
678      l1_regularization_strength: A float value, must be greater than or equal
679        to zero.
680      l2_regularization_strength: A float value, must be greater than or equal
681        to zero.
682      use_gradient_accumulation: setting this to `False` makes embedding
683        gradients calculation less accurate but faster. Please see
684        `optimization_parameters.proto` for details. for details.
685      clip_weight_min: the minimum value to clip by; None means -infinity.
686      clip_weight_max: the maximum value to clip by; None means +infinity.
687      weight_decay_factor: amount of weight decay to apply; None means that the
688        weights are not decayed.
689      multiply_weight_decay_factor_by_learning_rate: if true,
690        `weight_decay_factor` is multiplied by the current learning rate.
691      multiply_linear_by_learning_rate: When true, multiplies the usages of the
692        linear slot in the weight update by the learning rate. This is useful
693        when ramping up learning rate from 0 (which would normally produce
694        NaNs).
695      beta: The beta parameter for FTRL.
696      allow_zero_accumulator: Changes the implementation of the square root to
697        allow for the case of initial_accumulator_value being zero. This will
698        cause a slight performance drop.
699      clip_gradient_min: the minimum value to clip by; None means -infinity.
700        Gradient accumulation must be set to true if this is set.
701      clip_gradient_max: the maximum value to clip by; None means +infinity.
702        Gradient accumulation must be set to true if this is set.
703    """
704    super(FtrlParameters, self).__init__(
705        learning_rate=learning_rate,
706        use_gradient_accumulation=use_gradient_accumulation,
707        clip_weight_min=clip_weight_min,
708        clip_weight_max=clip_weight_max,
709        weight_decay_factor=weight_decay_factor,
710        multiply_weight_decay_factor_by_learning_rate=(
711            multiply_weight_decay_factor_by_learning_rate),
712        clip_gradient_min=clip_gradient_min,
713        clip_gradient_max=clip_gradient_max,
714    )
715    if learning_rate_power > 0.:
716      raise ValueError('learning_rate_power must be less than or equal to 0. '
717                       'got {}.'.format(learning_rate_power))
718
719    if initial_accumulator_value < 0.:
720      raise ValueError('initial_accumulator_value must be greater than or equal'
721                       ' to 0. got {}.'.format(initial_accumulator_value))
722
723    if l1_regularization_strength < 0.:
724      raise ValueError('l1_regularization_strength must be greater than or '
725                       'equal to 0. got {}.'.format(l1_regularization_strength))
726
727    if l2_regularization_strength < 0.:
728      raise ValueError('l2_regularization_strength must be greater than or '
729                       'equal to 0. got {}.'.format(l2_regularization_strength))
730
731    self.learning_rate_power = learning_rate_power
732    self.initial_accumulator_value = initial_accumulator_value
733    self.initial_linear_value = 0.0
734    self.l1_regularization_strength = l1_regularization_strength
735    self.l2_regularization_strength = l2_regularization_strength
736    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
737    self.beta = beta
738    self.allow_zero_accumulator = allow_zero_accumulator
739
740
741class ProximalYogiParameters(_OptimizationParameters):
742  # pylint: disable=line-too-long
743  """Optimization parameters for Proximal Yogi with TPU embeddings.
744
745  Implements the Yogi optimizer as described in
746  [Adaptive Methods for Nonconvex
747  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
748
749  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
750  `optimization_parameters` argument to set the optimizer and its parameters.
751  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
752  for more details.
753  """
754
755  # pylint: enable=line-too-long
756
757  def __init__(
758      self,
759      learning_rate: float = 0.01,
760      beta1: float = 0.9,
761      beta2: float = 0.999,
762      epsilon: float = 1e-3,
763      l1_regularization_strength: float = 0.0,
764      l2_regularization_strength: float = 0.0,
765      initial_accumulator_value: float = 1e-6,
766      use_gradient_accumulation: bool = True,
767      clip_weight_min: Optional[float] = None,
768      clip_weight_max: Optional[float] = None,
769      weight_decay_factor: Optional[float] = None,
770      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
771      clip_gradient_min: Optional[float] = None,
772      clip_gradient_max: Optional[float] = None,
773  ):
774    """Optimization parameters for Proximal Yogi.
775
776    Args:
777      learning_rate: a floating point value. The learning rate.
778      beta1: A float value. The exponential decay rate for the 1st moment
779        estimates.
780      beta2: A float value. The exponential decay rate for the 2nd moment
781        estimates.
782      epsilon: A small constant for numerical stability.
783      l1_regularization_strength: A float value, must be greater than or equal
784        to zero.
785      l2_regularization_strength: A float value, must be greater than or equal
786        to zero.
787      initial_accumulator_value: The starting value for accumulators. Only zero
788        or positive values are allowed.
789      use_gradient_accumulation: setting this to `False` makes embedding
790        gradients calculation less accurate but faster. Please see
791        `optimization_parameters.proto` for details. for details.
792      clip_weight_min: the minimum value to clip by; None means -infinity.
793      clip_weight_max: the maximum value to clip by; None means +infinity.
794      weight_decay_factor: amount of weight decay to apply; None means that the
795        weights are not decayed.
796      multiply_weight_decay_factor_by_learning_rate: if true,
797        `weight_decay_factor` is multiplied by the current learning rate.
798      clip_gradient_min: the minimum value to clip by; None means -infinity.
799        Gradient accumulation must be set to true if this is set.
800      clip_gradient_max: the maximum value to clip by; None means +infinity.
801        Gradient accumulation must be set to true if this is set.
802    """
803    super(ProximalYogiParameters, self).__init__(
804        learning_rate=learning_rate,
805        use_gradient_accumulation=use_gradient_accumulation,
806        clip_weight_min=clip_weight_min,
807        clip_weight_max=clip_weight_max,
808        weight_decay_factor=weight_decay_factor,
809        multiply_weight_decay_factor_by_learning_rate=(
810            multiply_weight_decay_factor_by_learning_rate),
811        clip_gradient_min=clip_gradient_min,
812        clip_gradient_max=clip_gradient_max,
813    )
814    if beta1 < 0. or beta1 >= 1.:
815      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
816    if beta2 < 0. or beta2 >= 1.:
817      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
818    if epsilon <= 0.:
819      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
820    if l1_regularization_strength < 0.:
821      raise ValueError('l1_regularization_strength must be greater than or '
822                       'equal to 0. got {}.'.format(l1_regularization_strength))
823    if l2_regularization_strength < 0.:
824      raise ValueError('l2_regularization_strength must be greater than or '
825                       'equal to 0. got {}.'.format(l2_regularization_strength))
826
827    self.beta1 = beta1
828    self.beta2 = beta2
829    self.epsilon = epsilon
830    self.l1_regularization_strength = l1_regularization_strength
831    self.l2_regularization_strength = l2_regularization_strength
832    self.initial_accumulator_value = initial_accumulator_value
833
834
835class MomentumParameters(_OptimizationParameters):
836  """Optimization parameters for Momentum with TPU embeddings.
837
838  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
839  `optimization_parameters` argument to set the optimizer and its parameters.
840  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
841  for more details.
842
843  ```
844  estimator = tf.estimator.tpu.TPUEstimator(
845      ...
846      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
847          ...
848          optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1),
849          ...))
850  ```
851
852  """
853
854  def __init__(
855      self,
856      learning_rate: float,
857      momentum: float,
858      use_nesterov: bool = False,
859      use_gradient_accumulation: bool = True,
860      clip_weight_min: Optional[float] = None,
861      clip_weight_max: Optional[float] = None,
862      weight_decay_factor: Optional[float] = None,
863      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
864      clip_gradient_min: Optional[float] = None,
865      clip_gradient_max: Optional[float] = None,
866  ):
867    """Optimization parameters for momentum.
868
869    Args:
870      learning_rate: a floating point value. The learning rate.
871      momentum: A `Tensor` or a floating point value.  The momentum.
872      use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
873        2013). This implementation always computes gradients at the value of the
874        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
875        variable(s) track the values called `theta_t + mu*v_t` in the paper.
876        This implementation is an approximation of the original formula, valid
877        for high values of momentum. It will compute the "adjusted gradient" in
878        NAG by assuming that the new gradient will be estimated by the current
879        average gradient plus the product of momentum and the change in the
880        average gradient.
881      use_gradient_accumulation: setting this to `False` makes embedding
882        gradients calculation less accurate but faster. Please see
883        `optimization_parameters.proto` for details.
884      clip_weight_min: the minimum value to clip by; None means -infinity.
885      clip_weight_max: the maximum value to clip by; None means +infinity.
886      weight_decay_factor: amount of weight decay to apply; None means that the
887        weights are not decayed.
888      multiply_weight_decay_factor_by_learning_rate: if true,
889        `weight_decay_factor` is multiplied by the current learning rate.
890      clip_gradient_min: the minimum value to clip by; None means -infinity.
891        Gradient accumulation must be set to true if this is set.
892      clip_gradient_max: the maximum value to clip by; None means +infinity.
893        Gradient accumulation must be set to true if this is set.
894    """
895    super(MomentumParameters, self).__init__(
896        learning_rate=learning_rate,
897        use_gradient_accumulation=use_gradient_accumulation,
898        clip_weight_min=clip_weight_min,
899        clip_weight_max=clip_weight_max,
900        weight_decay_factor=weight_decay_factor,
901        multiply_weight_decay_factor_by_learning_rate=(
902            multiply_weight_decay_factor_by_learning_rate),
903        clip_gradient_min=clip_gradient_min,
904        clip_gradient_max=clip_gradient_max,
905    )
906    self.momentum = momentum
907    self.use_nesterov = use_nesterov
908
909
910class RMSPropParameters(_OptimizationParameters):
911  """Optimization parameters for RMSProp with TPU embeddings.
912
913  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
914  `optimization_parameters` argument to set the optimizer and its parameters.
915  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
916  for more details.
917
918  ```
919  estimator = tf.estimator.tpu.TPUEstimator(
920      ...
921      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
922          ...
923          optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1),
924          ...))
925  ```
926
927  """
928
929  def __init__(
930      self,
931      learning_rate: float,
932      rho: float,
933      momentum: float,
934      epsilon: float,
935      use_gradient_accumulation: bool = True,
936      clip_weight_min: Optional[float] = None,
937      clip_weight_max: Optional[float] = None,
938      weight_decay_factor: Optional[float] = None,
939      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
940      clip_gradient_min: Optional[float] = None,
941      clip_gradient_max: Optional[float] = None,
942  ):
943    """Optimization parameters for RMS prop.
944
945    Args:
946      learning_rate: a floating point value. The learning rate.
947      rho: Discounting factor for the history/coming gradient
948      momentum: A scalar tensor.
949      epsilon: Small value to avoid zero denominator.
950      use_gradient_accumulation: setting this to `False` makes embedding
951        gradients calculation less accurate but faster. Please see
952        `optimization_parameters.proto` for details. for details.
953      clip_weight_min: the minimum value to clip by; None means -infinity.
954      clip_weight_max: the maximum value to clip by; None means +infinity.
955      weight_decay_factor: amount of weight decay to apply; None means that the
956        weights are not decayed.
957      multiply_weight_decay_factor_by_learning_rate: if true,
958        `weight_decay_factor` is multiplied by the current learning rate.
959      clip_gradient_min: the minimum value to clip by; None means -infinity.
960        Gradient accumulation must be set to true if this is set.
961      clip_gradient_max: the maximum value to clip by; None means +infinity.
962        Gradient accumulation must be set to true if this is set.
963    """
964    super(RMSPropParameters, self).__init__(
965        learning_rate=learning_rate,
966        use_gradient_accumulation=use_gradient_accumulation,
967        clip_weight_min=clip_weight_min,
968        clip_weight_max=clip_weight_max,
969        weight_decay_factor=weight_decay_factor,
970        multiply_weight_decay_factor_by_learning_rate=(
971            multiply_weight_decay_factor_by_learning_rate),
972        clip_gradient_min=clip_gradient_min,
973        clip_gradient_max=clip_gradient_max,
974    )
975    self.rho = rho
976    self.momentum = momentum
977    self.epsilon = epsilon
978
979
980@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
981class StochasticGradientDescentParameters(_OptimizationParameters):
982  """Optimization parameters for stochastic gradient descent for TPU embeddings.
983
984  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
985  `optimization_parameters` argument to set the optimizer and its parameters.
986  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
987  for more details.
988
989  ```
990  estimator = tf.estimator.tpu.TPUEstimator(
991      ...
992      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
993          ...
994          optimization_parameters=(
995              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
996  ```
997
998  """
999
1000  def __init__(
1001      self,
1002      learning_rate: float,
1003      clip_weight_min: Optional[float] = None,
1004      clip_weight_max: Optional[float] = None,
1005      weight_decay_factor: Optional[float] = None,
1006      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
1007      clip_gradient_min: Optional[float] = None,
1008      clip_gradient_max: Optional[float] = None,
1009  ):
1010    """Optimization parameters for stochastic gradient descent.
1011
1012    Args:
1013      learning_rate: a floating point value. The learning rate.
1014      clip_weight_min: the minimum value to clip by; None means -infinity.
1015      clip_weight_max: the maximum value to clip by; None means +infinity.
1016      weight_decay_factor: amount of weight decay to apply; None means that the
1017        weights are not decayed.
1018      multiply_weight_decay_factor_by_learning_rate: if true,
1019        `weight_decay_factor` is multiplied by the current learning rate.
1020      clip_gradient_min: the minimum value to clip by; None means -infinity.
1021      clip_gradient_max: the maximum value to clip by; None means +infinity.
1022    """
1023    # Gradient accumulation is generally a no-op for SGD, but if gradient
1024    # clipping is enabled, then we must also enable gradient accumulation.
1025    # In the other optimizers this up to the user, but we don't give the user
1026    # the option to turn gradient accumulation on or off for SGD.
1027    use_gradient_accumulation = False
1028    if (clip_gradient_min is not None or clip_gradient_max is not None):
1029      use_gradient_accumulation = True
1030    super(StochasticGradientDescentParameters, self).__init__(
1031        learning_rate=learning_rate,
1032        use_gradient_accumulation=use_gradient_accumulation,
1033        clip_weight_min=clip_weight_min,
1034        clip_weight_max=clip_weight_max,
1035        weight_decay_factor=weight_decay_factor,
1036        multiply_weight_decay_factor_by_learning_rate=(
1037            multiply_weight_decay_factor_by_learning_rate),
1038        clip_gradient_min=clip_gradient_min,
1039        clip_gradient_max=clip_gradient_max,
1040    )
1041
1042
1043class FrequencyEstimatorParameters(_OptimizationParameters):
1044  """Optimization parameters for Frequency Estimator TPU embeddings.
1045
1046  This is a non-standard optimizer, which returns the estimated frequency of
1047  lookup for the feature passed to it. It should only be used on a table of
1048  width 1. The gradient fed back to the TPU embedding should always be zero.
1049  This can be acomplished via using `tf.stop_gradients` on the feature before
1050  using it.
1051
1052  You must use the dynamic learning rate mechanism to set the 'learning rate'
1053  for this table to be the a float32 cast of the global training step counter.
1054
1055  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
1056  details on this optimizer.
1057
1058  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
1059  `optimization_parameters` argument to set the optimizer and its parameters.
1060  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
1061  for more details.
1062
1063  ```
1064  estimator = tf.estimator.tpu.TPUEstimator(
1065      ...
1066      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1067          ...
1068          optimization_parameters=FrequencyEstimatorParameters(0.1),
1069          ...))
1070  ```
1071
1072  """
1073
1074  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
1075               weight_exponent: float):
1076    """Optimization parameters for frequency estimator.
1077
1078    Args:
1079      tau: Learning rate between (0, 1) that is used to update the array.
1080      max_delta: Maximum value of delta, the difference between the current
1081        global step and the last global step at which the row was sampled.
1082      outlier_threshold: Threshold used to determine whether the current update
1083        is an outlier.
1084      weight_exponent: The weight exponent used to transform the estimated delta
1085        into weights.
1086    """
1087    super(FrequencyEstimatorParameters, self).__init__(
1088        learning_rate=1.0,
1089        use_gradient_accumulation=True,
1090        clip_weight_min=None,
1091        clip_weight_max=None,
1092        weight_decay_factor=None,
1093        multiply_weight_decay_factor_by_learning_rate=None,
1094    )
1095    self.tau = tau
1096    self.max_delta = max_delta
1097    self.outlier_threshold = outlier_threshold
1098    self.weight_exponent = weight_exponent
1099
1100
1101DeviceConfig = collections.namedtuple('DeviceConfig',
1102                                      ['num_hosts', 'num_cores', 'job_name'])
1103
1104
1105class TPUEmbedding(object):
1106  """API for using TPU for embedding.
1107
1108    Example:
1109    ```
1110    table_config_user = tpu_embedding.TableConfig(
1111        vocabulary_size=4, dimension=2,
1112        initializer=initializer, combiner='mean')
1113    table_to_config_dict = {'video': table_config_video,
1114                          'user': table_config_user}
1115    feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
1116                              'favorited': tpu_embedding.FeatureConfig('video'),
1117                              'friends': tpu_embedding.FeatureConfig('user')}
1118    batch_size = 4
1119    num_hosts = 1
1120    optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
1121    mode = tpu_embedding.TRAINING
1122    embedding = tpu_embedding.TPUEmbedding(
1123        table_to_config_dict, feature_to_config_dict,
1124        batch_size, num_hosts, mode, optimization_parameters)
1125
1126    batch_size_per_core = embedding.batch_size_per_core
1127    sparse_features_list = []
1128    for host in hosts:
1129      with ops.device(host):
1130        for _ in range(embedding.num_cores_per_host):
1131          sparse_features = {}
1132          sparse_features['watched'] = sparse_tensor.SparseTensor(...)
1133          sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
1134          sparse_features['friends'] = sparse_tensor.SparseTensor(...)
1135          sparse_features_list.append(sparse_features)
1136
1137    enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
1138    embedding_variables_and_ops = embedding.create_variables_and_ops()
1139
1140    def computation():
1141      activations = embedding.get_activations()
1142      loss = compute_loss(activations)
1143
1144      base_optimizer = gradient_descent.GradientDescentOptimizer(
1145          learning_rate=1)
1146      cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
1147          base_optimizer)
1148
1149      train_op = cross_shard_optimizer.minimize(loss)
1150      gradients = (
1151          tpu_embedding_gradient.get_gradients_through_compute_gradients(
1152              cross_shard_optimizer, loss, activations)
1153      send_gradients_op = embedding.generate_send_gradients_op(gradients)
1154      with ops.control_dependencies([train_op, send_gradients_op]):
1155        loss = array_ops.identity(loss)
1156
1157    loss = tpu.shard(computation,
1158                     num_shards=embedding.num_cores)
1159
1160    with self.test_session() as sess:
1161      sess.run(tpu.initialize_system(embedding_config=
1162                                     embedding.config_proto))
1163      sess.run(variables.global_variables_initializer())
1164      sess.run(embedding_variables_and_ops.load_ops())
1165      sess.run(enqueue_ops)
1166      loss_val = sess.run(loss)
1167    ```
1168
1169  Example with weight decay:
1170
1171  >>> def learning_rate_fn(global_step):
1172  ...   return tf.compat.v1.train.polynomial_decay(
1173  ...     learning_rate=5e-5,
1174  ...     global_step=global_step,
1175  ...     decay_steps=100000,
1176  ...     end_learning_rate=0.0)
1177  >>> wordpiece_table_config = TableConfig(
1178  ...   vocabulary_size=119547,
1179  ...   dimension=256,
1180  ...   learning_rate_fn=learning_rate_fn)
1181  >>> wordpiece_feature_config = FeatureConfig(
1182  ...   table_id='bert/embeddings/word_embeddings',
1183  ...   max_sequence_length=512)
1184  >>> optimization_parameters = AdamParameters(
1185  ...   learning_rate=5e-5,
1186  ...   epsilon=1e-6,
1187  ...   weight_decay_factor=0.01,
1188  ...   multiply_weight_decay_factor_by_learning_rate=True)
1189  >>> tpu_embedding = TPUEmbedding(
1190  ...  table_to_config_dict={
1191  ...    'bert/embeddings/word_embeddings': wordpiece_table_config,
1192  ...  },
1193  ...  feature_to_config_dict={'input_ids': wordpiece_feature_config},
1194  ...  batch_size=128,
1195  ...  mode=TRAINING,
1196  ...  optimization_parameters=optimization_parameters,
1197  ...  master='')
1198  >>> with tf.Graph().as_default():
1199  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
1200  ...     embedding_config=tpu_embedding.config_proto)
1201  ...   tf.compat.v1.Session().run(init_tpu_op)
1202  """
1203
1204  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
1205  # the feature should not be used to update embedding table (cr/204852758,
1206  # cr/204940540). Also, this can support different combiners for different
1207  # features within the same table.
1208  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
1209  # to `FeatureConfig`?
1210
1211  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
1212  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
1213  # respectively?
1214
1215  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
1216  # for-loops around construction of inputs.
1217
1218  # `optimization_parameter` applies to all tables. If the need arises,
1219  # we can add `optimization_parameters` to `TableConfig` to override this
1220  # global setting.
1221  def __init__(self,
1222               table_to_config_dict,
1223               feature_to_config_dict,
1224               batch_size,
1225               mode,
1226               master=None,
1227               optimization_parameters=None,
1228               cluster_def=None,
1229               pipeline_execution_with_tensor_core=False,
1230               partition_strategy='div',
1231               profile_data_directory=None,
1232               device_config=None,
1233               master_job_name=None):
1234    """API for using TPU for embedding lookups.
1235
1236    Args:
1237      table_to_config_dict: A dictionary mapping from string of table name to
1238        `TableConfig`. Table refers to an embedding table, e.g. `params`
1239        argument to `tf.nn.embedding_lookup_sparse()`.
1240      feature_to_config_dict: A dictionary mapping from string of feature name
1241        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
1242        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
1243      batch_size: An `int` representing the global batch size.
1244      mode: `TRAINING` or `INFERENCE`.
1245      master: A `string` representing the TensorFlow master to use.
1246      optimization_parameters: `AdagradParameters`, `AdamParameters`,
1247        `Stochasticgradientdescentparameters`. Must be set in training unless
1248        all tables specify their own optimizers. And it must be `None` in
1249        inference.
1250      cluster_def: A ClusterDef object describing the TPU cluster.
1251      pipeline_execution_with_tensor_core: setting this to `True` makes training
1252        faster, but trained model will be different if step N and step N+1
1253        involve the same set of embedding IDs. Please see
1254        `tpu_embedding_configuration.proto` for details.
1255      partition_strategy: A string, either 'mod' or 'div', specifying how to map
1256        the lookup id to the embedding tensor. For more information see
1257        `tf.nn.embedding_lookup_sparse`.
1258      profile_data_directory: Directory where embedding lookup statistics are
1259        stored. These statistics summarize information about the inputs to the
1260        embedding lookup operation, in particular, the average number of
1261        embedding IDs per example and how well the embedding IDs are load
1262        balanced across the system. The lookup statistics are used during TPU
1263        initialization for embedding table partitioning. Collection of lookup
1264        statistics is done at runtime by  profiling the embedding inputs: only
1265        3% of input samples are profiled to minimize host CPU overhead. Once
1266        a suitable number of samples are profiled, the lookup statistics are
1267        saved to table-specific files in the profile data directory generally
1268        at the end of a TPU training loop. The filename corresponding to each
1269        table is obtained by hashing table specific parameters (e.g., table
1270        name and number of features) and global configuration parameters (e.g.,
1271        sharding strategy and task count). The same profile data directory can
1272        be shared among several models to reuse embedding lookup statistics.
1273      device_config: A DeviceConfig instance, used when `master` and
1274        `cluster_def` are both `None`.
1275      master_job_name: if set, overrides the master job name used to schedule
1276        embedding ops.
1277
1278    Raises:
1279      ValueError: if any input is invalid.
1280    """
1281    if partition_strategy not in ('div', 'mod'):
1282      raise ValueError(
1283          'Invalid partition_strategy {}'.format(partition_strategy))
1284    self._partition_strategy = partition_strategy
1285
1286    self._profile_data_directory = profile_data_directory
1287
1288    _validate_table_to_config_dict(table_to_config_dict)
1289    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
1290    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
1291
1292    _validate_feature_to_config_dict(table_to_config_dict,
1293                                     feature_to_config_dict)
1294    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
1295    self._table_to_features_dict, self._table_to_num_features_dict = (
1296        _create_table_to_features_and_num_features_dicts(
1297            self._feature_to_config_dict))
1298    self._combiners = _create_combiners(self._table_to_config_dict,
1299                                        self._table_to_features_dict)
1300
1301    self._batch_size = batch_size
1302
1303    if master is None and cluster_def is None:
1304      if device_config is None:
1305        raise ValueError('When master and cluster_def are both None,'
1306                         'device_config must be set but is not.')
1307      if device_config.num_cores % device_config.num_hosts:
1308        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
1309                         'but does not.'.format(device_config.num_cores,
1310                                                device_config.num_hosts))
1311      self._num_hosts = device_config.num_hosts
1312      self._num_cores = device_config.num_cores
1313      self._num_cores_per_host = self._num_cores // self._num_hosts
1314      self._hosts = [
1315          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
1316          for i in range(self._num_hosts)
1317      ]
1318    else:
1319      tpu_system_metadata = (
1320          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
1321              master,
1322              cluster_def=cluster_def))
1323      if tpu_system_metadata.num_cores == 0:
1324        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
1325                         'TPUs.'.format(master))
1326      self._num_hosts = tpu_system_metadata.num_hosts
1327      if master_job_name is None:
1328        try:
1329          master_job_name = tpu_system_metadata_lib.master_job(
1330              master, cluster_def)
1331        except ValueError as e:
1332          raise ValueError(str(e) + ' Please specify a master_job_name.')
1333      self._hosts = []
1334      for device in tpu_system_metadata.devices:
1335        if 'device:CPU:' in device.name and (master_job_name is None or
1336                                             master_job_name in device.name):
1337          self._hosts.append(device.name)
1338      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
1339      self._num_cores = tpu_system_metadata.num_cores
1340
1341    _validate_batch_size(self._batch_size, self._num_cores)
1342    self._batch_size_per_core = self._batch_size // self._num_cores
1343
1344    # TODO(shizhiw): remove `mode`?
1345    if mode == TRAINING:
1346      _validate_optimization_parameters(optimization_parameters,
1347                                        self._table_to_config_dict)
1348      self._optimization_parameters = optimization_parameters
1349    elif mode == INFERENCE:
1350      if optimization_parameters is not None:
1351        raise ValueError('`optimization_parameters` should be `None` '
1352                         'for inference mode.')
1353      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
1354    else:
1355      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
1356          TRAINING, INFERENCE, mode))
1357    self._mode = mode
1358
1359    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
1360    # and create special handler for inference that inherits from
1361    # StochasticGradientDescentHandler with more user-friendly error message
1362    # on get_slot().
1363    self._optimizer_handler_dict = self._get_optimizer_handler_by_table()
1364
1365    self._pipeline_execution_with_tensor_core = (
1366        pipeline_execution_with_tensor_core)
1367    self._learning_rate_fn = list(
1368        set(c.learning_rate_fn
1369            for c in self._table_to_config_dict.values()
1370            if c.learning_rate_fn is not None))
1371    self._learning_rate_fn_to_tag = {
1372        fn: id for id, fn in enumerate(self._learning_rate_fn)
1373    }
1374
1375    self._config_proto = self._create_config_proto()
1376
1377  @property
1378  def hosts(self):
1379    """A list of device names for CPU hosts.
1380
1381    Returns:
1382      A list of device names for CPU hosts.
1383    """
1384    return copy.copy(self._hosts)
1385
1386  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
1387  # to be consistent with `tpu_embedding_configuration.proto`.
1388  @property
1389  def num_cores_per_host(self):
1390    """Number of TPU cores on a CPU host.
1391
1392    Returns:
1393      Number of TPU cores on a CPU host.
1394    """
1395    return self._num_cores_per_host
1396
1397  @property
1398  def num_cores(self):
1399    """Total number of TPU cores on all hosts.
1400
1401    Returns:
1402      Total number of TPU cores on all hosts.
1403    """
1404    return self._num_cores
1405
1406  @property
1407  def batch_size_per_core(self):
1408    """Batch size for each TPU core.
1409
1410    The sparse tensors in `sparse_features_list` to `generate_enqueue_ops`
1411       must have batch dimension equal to this.
1412
1413    Returns:
1414      Batch size for each TPU core.
1415    """
1416    return self._batch_size_per_core
1417
1418  @property
1419  def config_proto(self):
1420    """Create embedding config proto for `tpu.initialize_system()`.
1421
1422    Returns:
1423      an `TPUEmbeddingConfiguration` proto describing the desired
1424         configuration of the hardware embedding lookup tables, which
1425         is passed to `tpu.initialize_system()`.
1426    """
1427    return self._config_proto
1428
1429  @property
1430  def table_to_config_dict(self):
1431    return copy.copy(self._table_to_config_dict)
1432
1433  @property
1434  def feature_to_config_dict(self):
1435    return copy.copy(self._feature_to_config_dict)
1436
1437  @property
1438  def table_to_features_dict(self):
1439    return copy.copy(self._table_to_features_dict)
1440
1441  @property
1442  def optimization_parameters(self):
1443    return self._optimization_parameters
1444
1445  def _create_config_proto(self):
1446    """Create `TPUEmbeddingConfiguration`."""
1447    config_proto = elc.TPUEmbeddingConfiguration()
1448    for table in self._table_to_config_dict:
1449      table_descriptor = config_proto.table_descriptor.add()
1450      table_descriptor.name = table
1451
1452      table_config = self._table_to_config_dict[table]
1453      # For small tables, we pad to the number of hosts so that at least one
1454      # id will be assigned to each host.
1455      table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
1456                                             len(self.hosts))
1457      table_descriptor.dimension = table_config.dimension
1458
1459      table_descriptor.num_features = self._table_to_num_features_dict[table]
1460
1461      optimization_parameters = (
1462          self._optimizer_handler_dict[table].get_optimization_parameters())
1463
1464      parameters = table_descriptor.optimization_parameters
1465      if table_config.learning_rate:
1466        parameters.learning_rate.constant = table_config.learning_rate
1467      elif table_config.learning_rate_fn:
1468        parameters.learning_rate.dynamic.tag = (
1469            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
1470      else:
1471        parameters.learning_rate.constant = (
1472            optimization_parameters.learning_rate)
1473      parameters.gradient_accumulation_status = (
1474          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
1475          if optimization_parameters.use_gradient_accumulation else
1476          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
1477
1478      if optimization_parameters.clip_gradient_min is not None:
1479        parameters.gradient_clipping_limits.lower.value = (
1480            optimization_parameters.clip_gradient_min)
1481      if optimization_parameters.clip_gradient_max is not None:
1482        parameters.gradient_clipping_limits.upper.value = (
1483            optimization_parameters.clip_gradient_max)
1484
1485      if optimization_parameters.clip_weight_min is not None:
1486        parameters.clipping_limits.lower.value = (
1487            optimization_parameters.clip_weight_min)
1488      if optimization_parameters.clip_weight_max is not None:
1489        parameters.clipping_limits.upper.value = (
1490            optimization_parameters.clip_weight_max)
1491      if optimization_parameters.weight_decay_factor:
1492        parameters.weight_decay_factor = (
1493            optimization_parameters.weight_decay_factor)
1494        if (optimization_parameters
1495            .multiply_weight_decay_factor_by_learning_rate):
1496          parameters.multiply_weight_decay_factor_by_learning_rate = True
1497      if table_config.hot_id_replication:
1498        parameters.hot_id_replication_configuration.status = (
1499            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
1500      optimizer_handler = self._optimizer_handler_dict[table]
1501      optimizer_handler.set_optimization_parameters(table_descriptor)
1502
1503    config_proto.mode = self._mode
1504    config_proto.batch_size_per_tensor_core = self._batch_size_per_core
1505    config_proto.num_hosts = self._num_hosts
1506    config_proto.num_tensor_cores = self._num_cores
1507    config_proto.sharding_strategy = (
1508        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
1509        if self._partition_strategy == 'div' else
1510        elc.TPUEmbeddingConfiguration.MOD)
1511    config_proto.pipeline_execution_with_tensor_core = (
1512        self._pipeline_execution_with_tensor_core)
1513    if self._profile_data_directory:
1514      config_proto.profile_data_directory = self._profile_data_directory
1515
1516    return config_proto
1517
1518  def create_variables_and_ops(self,
1519                               embedding_variable_name_by_table=None,
1520                               slot_variable_names_by_table=None):
1521    """Create embedding and slot variables, with ops to load and retrieve them.
1522
1523    N.B.: the retrieve embedding variables (including slot variables) ops are
1524    returned as lambda fn, as the call side might want to impose control
1525    dependencies between the TPU computation and retrieving actions. For
1526    example, the following code snippet ensures the TPU computation finishes
1527    first, and then we pull the variables back from TPU to CPU.
1528
1529    ```
1530    updates_ops = []
1531    with ops.control_dependencies([loss]):
1532      for op_fn in retrieve_parameters_op_fns:
1533        update_ops.append(op_fn())
1534    ```
1535
1536    Args:
1537      embedding_variable_name_by_table: A dictionary mapping from string of
1538        table name to string of embedding variable name. If `None`, defaults
1539        from `get_default_slot_variable_names()` will be used.
1540      slot_variable_names_by_table: A dictionary mapping from string of table
1541        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
1542        `None`, defaults from `get_default_slot_variable_names()` will be used.
1543
1544    Returns:
1545      `tpu_embedding.VariablesAndOps` with:
1546        A dictionary mapping from string of table name to embedding variables,
1547        A dictionary mapping from string of table name to AdagradSlotVariable,
1548         AdamSlotVariables etc with slot variables,
1549        A function which returns a list of ops to load embedding and slot
1550         variables from CPU to TPU.
1551        A function which returns a list of ops to retrieve embedding and slot
1552         variables from TPU to CPU.
1553    """
1554    embedding_variables_by_table = {}
1555    slot_variables_by_table = {}
1556    load_op_fns = []
1557    retrieve_op_fns = []
1558
1559    for i, table in enumerate(self._table_to_config_dict):
1560      if embedding_variable_name_by_table:
1561        embedding_variable_name = embedding_variable_name_by_table[table]
1562      else:
1563        embedding_variable_name = table
1564      if slot_variable_names_by_table:
1565        slot_variable_names = slot_variable_names_by_table[table]
1566      else:
1567        optimizer_handler = self._optimizer_handler_dict[table]
1568        slot_variable_names = (
1569            optimizer_handler.get_default_slot_variable_names(table))
1570
1571      # TODO(b/139144091): Multi-host support for mid-level API in
1572      #  eager context (TF 2.0)
1573      # Workaround below allows single-host use case in TF 2.0
1574      if context.executing_eagerly():
1575        device = ''
1576      else:
1577        device = _create_device_fn(self._hosts)
1578
1579      with ops.device(device):
1580        table_variables = _create_partitioned_variables(
1581            name=embedding_variable_name,
1582            num_hosts=self._num_hosts,
1583            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
1584            embedding_dimension=self._table_to_config_dict[table].dimension,
1585            initializer=self._table_to_config_dict[table].initializer,
1586            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
1587        embedding_variables_by_table[table] = table_variables
1588
1589        # Only loads embedding config to load/retrieve nodes for the first table
1590        # on the first host, other nodes would use config from the first node.
1591        config = None if i else self.config_proto.SerializeToString()
1592        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
1593            self._optimizer_handler_dict[table].create_variables_and_ops(
1594                table, slot_variable_names, self._num_hosts,
1595                self._table_to_config_dict[table], table_variables, config))
1596        slot_variables_by_table[table] = slot_variables_for_table
1597        load_op_fns.append(load_ops_fn)
1598        retrieve_op_fns.append(retrieve_ops_fn)
1599
1600    def load_ops():
1601      """Calls and returns the load ops for each embedding table.
1602
1603      Returns:
1604        A list of ops to load embedding and slot variables from CPU to TPU.
1605      """
1606      load_ops_list = []
1607      for load_op_fn in load_op_fns:
1608        load_ops_list.extend(load_op_fn())
1609      return load_ops_list
1610
1611    def retrieve_ops():
1612      """Calls and returns the retrieve ops for each embedding table.
1613
1614      Returns:
1615        A list of ops to retrieve embedding and slot variables from TPU to CPU.
1616      """
1617      retrieve_ops_list = []
1618      for retrieve_op_fn in retrieve_op_fns:
1619        retrieve_ops_list.extend(retrieve_op_fn())
1620      return retrieve_ops_list
1621
1622    return VariablesAndOps(embedding_variables_by_table,
1623                           slot_variables_by_table, load_ops, retrieve_ops)
1624
1625  def generate_enqueue_ops(
1626      self,
1627      enqueue_datas_list,
1628      mode_override=None,
1629      ragged=False,
1630  ):
1631    """Generate enqueue ops.
1632
1633    Args:
1634      enqueue_datas_list: a list of dictionary mapping from string of feature
1635        names to EnqueueData. Each dictionary is for one TPU core. Dictionaries
1636        for the same host should be contiguous in the list.
1637      mode_override: A string input that overrides the mode specified in the
1638        TPUEmbeddingConfiguration. Supported values are {'unspecified',
1639        'inference', 'training', 'backward_pass_only'}. When set to
1640        'unspecified', the mode set in TPUEmbeddingConfiguration is used,
1641        otherwise mode_override is used (optional).
1642      ragged: If True, creates RaggedTensor enqueue ops rather than
1643        SparseTensor.
1644
1645    Returns:
1646      Ops to enqueue to TPU for embedding.
1647    """
1648    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
1649    return [
1650        self._generate_enqueue_op(  # pylint: disable=g-complex-comprehension
1651            enqueue_datas,
1652            device_ordinal=i % self._num_cores_per_host,
1653            mode_override=mode_override,
1654            ragged=ragged,
1655        ) for i, enqueue_datas in enumerate(enqueue_datas_list)
1656    ]
1657
1658  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
1659                                                        enqueue_datas_list):
1660    """Validate `enqueue_datas_list`."""
1661
1662    def _check_agreement(data, name, feature, enqueue_data):
1663      """Helper function to check device agreement."""
1664      if (data is not None and
1665          data.device != enqueue_data.embedding_indices.device):
1666        raise ValueError('Device of {0} does not agree with that of'
1667                         'embedding_indices for feature {1}.'.format(
1668                             name, feature))
1669
1670    feature_set = set(self._feature_to_config_dict.keys())
1671    contiguous_device = None
1672    for i, enqueue_datas in enumerate(enqueue_datas_list):
1673      used_feature_set = set(enqueue_datas.keys())
1674
1675      # Check features are valid.
1676      missing_feature_set = feature_set - used_feature_set
1677      if missing_feature_set:
1678        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
1679                         'in `feature_to_config_dict`: {}.'.format(
1680                             i, missing_feature_set))
1681
1682      extra_feature_set = used_feature_set - feature_set
1683      if extra_feature_set:
1684        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
1685                         'in `feature_to_config_dict`: {}.'.format(
1686                             i, extra_feature_set))
1687
1688      device = None
1689      device_feature = None
1690      for feature, enqueue_data in six.iteritems(enqueue_datas):
1691        combiner = self._table_to_config_dict[
1692            self._feature_to_config_dict[feature].table_id].combiner
1693
1694        if isinstance(enqueue_data, EnqueueData):
1695          if enqueue_data.sample_indices is None and combiner:
1696            logging.warn(
1697                'No sample indices set for features %f table %f but '
1698                'combiner is set to %s.', feature,
1699                self._feature_to_config_dict[feature].table_id, combiner)
1700          _check_agreement(enqueue_data.sample_indices, 'sample_indices',
1701                           feature, enqueue_data)
1702          _check_agreement(enqueue_data.aggregation_weights,
1703                           'aggregation_weights', feature, enqueue_data)
1704
1705        elif isinstance(enqueue_data, RaggedEnqueueData):
1706          if enqueue_data.sample_splits is None and combiner:
1707            logging.warn(
1708                'No sample splits set for features %f table %f but '
1709                'combiner is set to %s.', feature,
1710                self._feature_to_config_dict[feature].table_id, combiner)
1711          _check_agreement(enqueue_data.sample_splits, 'sample_splits', feature,
1712                           enqueue_data)
1713          _check_agreement(enqueue_data.aggregation_weights,
1714                           'aggregation_weights', feature, enqueue_data)
1715        else:
1716          raise ValueError(
1717              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
1718              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
1719                  i, feature))
1720        # Check all features are on the same device.
1721        if device is None:
1722          device = enqueue_data.embedding_indices.device
1723          device_feature = feature
1724        else:
1725          if device != enqueue_data.embedding_indices.device:
1726            raise ValueError('Devices are different between features in '
1727                             '`enqueue_datas_list[{}]`; '
1728                             'devices: {}, {}; features: {}, {}.'.format(
1729                                 i, device,
1730                                 enqueue_data.embedding_indices.device, feature,
1731                                 device_feature))
1732
1733      if i % self._num_cores_per_host:
1734        if device != contiguous_device:
1735          raise ValueError('We expect the `enqueue_datas` which are on the '
1736                           'same host to be contiguous in '
1737                           '`enqueue_datas_list`, '
1738                           '`enqueue_datas_list[{}]` is on device {}, '
1739                           'but is expected to be on device {}.'.format(
1740                               i, device, contiguous_device))
1741      else:
1742        contiguous_device = device
1743
1744  def _generate_enqueue_op(self,
1745                           enqueue_datas,
1746                           device_ordinal,
1747                           mode_override=None,
1748                           ragged=False):
1749    """Creates op for enqueuing batch to TPU."""
1750    enqueue_data0 = list(enqueue_datas.values())[0]
1751    with ops.colocate_with(enqueue_data0.embedding_indices):
1752      if ragged:
1753        # note that this is currently identical in behavior
1754        return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
1755            device_ordinal=device_ordinal,
1756            combiners=self._combiners,
1757            mode_override=mode_override,
1758            **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
1759      else:
1760        return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
1761            device_ordinal=device_ordinal,
1762            combiners=self._combiners,
1763            mode_override=mode_override,
1764            **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
1765
1766  def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
1767    """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
1768
1769    Args:
1770      enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
1771
1772    Returns:
1773      Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
1774    """
1775
1776    kwargs = {
1777        'sample_splits': [],
1778        'embedding_indices': [],
1779        'aggregation_weights': [],
1780        'table_ids': [],
1781        'max_sequence_lengths': [],
1782    }
1783    int_zeros = array_ops.zeros((0,), dtype=dtypes.int64)
1784    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)
1785    for table_id, table in enumerate(self._table_to_features_dict):
1786      features = self._table_to_features_dict[table]
1787      for feature in features:
1788        enqueue_data = enqueue_datas[feature]
1789
1790        kwargs['sample_splits'].append(
1791            enqueue_data.sample_splits
1792            if enqueue_data.sample_splits is not None else int_zeros)
1793
1794        kwargs['aggregation_weights'].append(
1795            enqueue_data.aggregation_weights
1796            if enqueue_data.aggregation_weights is not None else float_zeros)
1797
1798        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
1799
1800        kwargs['table_ids'].append(table_id)
1801        kwargs['max_sequence_lengths'].append(
1802            self._feature_to_config_dict[feature].max_sequence_length)
1803
1804    return kwargs
1805
1806  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
1807    """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
1808
1809    Args:
1810      enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
1811
1812    Returns:
1813      Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
1814    """
1815    kwargs = {
1816        'sample_indices': [],
1817        'embedding_indices': [],
1818        'aggregation_weights': [],
1819        'table_ids': [],
1820        'max_sequence_lengths': [],
1821    }
1822    int_zeros = array_ops.zeros((0,), dtype=dtypes.int64)
1823    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)
1824    for table_id, table in enumerate(self._table_to_features_dict):
1825      features = self._table_to_features_dict[table]
1826      for feature in features:
1827        enqueue_data = enqueue_datas[feature]
1828
1829        kwargs['sample_indices'].append(
1830            enqueue_data.sample_indices
1831            if enqueue_data.sample_indices is not None else int_zeros)
1832
1833        kwargs['aggregation_weights'].append(
1834            enqueue_data.aggregation_weights
1835            if enqueue_data.aggregation_weights is not None else float_zeros)
1836
1837        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
1838
1839        kwargs['table_ids'].append(table_id)
1840        kwargs['max_sequence_lengths'].append(
1841            self._feature_to_config_dict[feature].max_sequence_length)
1842
1843    return kwargs
1844
1845  def get_activations(self):
1846    """Get activations for features.
1847
1848    This should be called within `computation` that is passed to
1849      `tpu.replicate` and friends.
1850
1851    Returns:
1852      A dictionary mapping from `String` of feature name to `Tensor`
1853        of activation.
1854    """
1855    recv_activations = tpu_ops.recv_tpu_embedding_activations(
1856        num_outputs=len(self._table_to_config_dict),
1857        config=self._config_proto.SerializeToString())
1858
1859    activations = collections.OrderedDict()
1860    for table_id, table in enumerate(self._table_to_features_dict):
1861      features = self._table_to_features_dict[table]
1862      num_features = self._table_to_num_features_dict[table]
1863      feature_index = 0
1864      table_activations = array_ops.reshape(
1865          recv_activations[table_id],
1866          [self.batch_size_per_core, num_features, -1])
1867      for feature in features:
1868        seq_length = self._feature_to_config_dict[feature].max_sequence_length
1869        if not seq_length:
1870          activations[feature] = table_activations[:, feature_index, :]
1871          feature_index = feature_index + 1
1872        else:
1873          activations[feature] = (
1874              table_activations[:,
1875                                feature_index:(feature_index + seq_length), :])
1876          feature_index = feature_index + seq_length
1877
1878    return activations
1879
1880  def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
1881    """Send gradient to TPU embedding.
1882
1883    Args:
1884      feature_to_gradient_dict: dict mapping feature names to gradient wrt
1885        activations.
1886      step: the current global step, used for dynamic learning rate.
1887
1888    Returns:
1889      SendTPUEmbeddingGradients Op.
1890
1891    Raises:
1892      RuntimeError: If `mode` is not `TRAINING`.
1893    """
1894    if self._mode != TRAINING:
1895      raise RuntimeError('Only in training mode gradients need to '
1896                         'be sent to TPU embedding; got mode {}.'.format(
1897                             self._mode))
1898    if step is None and self._learning_rate_fn:
1899      raise ValueError('There are dynamic learning rates but step is None.')
1900
1901    gradients = []
1902    for table in self._table_to_features_dict:
1903      features = self._table_to_features_dict[table]
1904      table_gradients = []
1905      for feature in features:
1906        gradient = feature_to_gradient_dict[feature]
1907        # Expand dims for non-sequence feature to match sequence features.
1908        if gradient.shape.ndims == 2:
1909          gradient = array_ops.expand_dims(gradient, 1)
1910        table_gradients.append(gradient)
1911      interleaved_table_grads = array_ops.reshape(
1912          array_ops.concat(table_gradients, axis=1),
1913          [-1, array_ops.shape(table_gradients[0])[-1]])
1914      gradients.append(interleaved_table_grads)
1915
1916    return tpu_ops.send_tpu_embedding_gradients(
1917        inputs=gradients,
1918        learning_rates=[
1919            math_ops.cast(fn(step), dtype=dtypes.float32)
1920            for fn in self._learning_rate_fn
1921        ],
1922        config=self.config_proto.SerializeToString())
1923
1924  def _get_optimizer_handler_by_table(self):
1925    optimizer_handlers = {}
1926    for table, table_config in self.table_to_config_dict.items():
1927      if table_config.optimization_parameters is not None:
1928        optimizer = table_config.optimization_parameters
1929      else:
1930        optimizer = self._optimization_parameters
1931      optimizer_handlers[table] = _get_optimization_handler(optimizer)
1932
1933    return optimizer_handlers
1934
1935
1936def _validate_table_to_config_dict(table_to_config_dict):
1937  """Validate `table_to_config_dict`."""
1938  for k, v in six.iteritems(table_to_config_dict):
1939    if not isinstance(v, TableConfig):
1940      raise ValueError('Value of `table_to_config_dict` must be of type '
1941                       '`TableConfig`, got {} for {}.'.format(type(v), k))
1942
1943
1944def _validate_feature_to_config_dict(table_to_config_dict,
1945                                     feature_to_config_dict):
1946  """Validate `feature_to_config_dict`."""
1947  used_table_set = set(
1948      [feature.table_id for feature in feature_to_config_dict.values()])
1949  table_set = set(table_to_config_dict.keys())
1950
1951  unused_table_set = table_set - used_table_set
1952  if unused_table_set:
1953    raise ValueError(
1954        '`table_to_config_dict` specifies table that is not '
1955        'used in `feature_to_config_dict`: {}.'.format(unused_table_set))
1956
1957  extra_table_set = used_table_set - table_set
1958  if extra_table_set:
1959    raise ValueError(
1960        '`feature_to_config_dict` refers to a table that is not '
1961        'specified in `table_to_config_dict`: {}.'.format(extra_table_set))
1962
1963
1964def _validate_batch_size(batch_size, num_cores):
1965  if batch_size % num_cores:
1966    raise ValueError('`batch_size` is not a multiple of number of '
1967                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
1968                         batch_size, num_cores))
1969
1970
1971def _validate_optimization_parameters(optimization_parameters,
1972                                      table_to_config_dict):
1973  """Validate global optimization_parameters and per table optimizers.
1974
1975  If global optimizer is `None`, all table optimizers should be non `None`.
1976
1977  Args:
1978      optimization_parameters: global optimizer provided in `TPUEmbedding`
1979        constructor.
1980      table_to_config_dict: A dictionary mapping from string of table name to
1981        `TableConfig`.
1982  """
1983  tbl_optimizer_missing = False
1984  for _, table_config in table_to_config_dict.items():
1985    if table_config.optimization_parameters is None:
1986      tbl_optimizer_missing = True
1987      break
1988
1989  if optimization_parameters:
1990    if not isinstance(optimization_parameters, _OptimizationParameters):
1991      raise ValueError('`optimization_parameters` must inherit from '
1992                       '`_OptimizationParameters`. '
1993                       '`type(optimization_parameters)`={}'.format(
1994                           type(optimization_parameters)))
1995  else:
1996    # Missing global optimization_parameters.
1997    if tbl_optimizer_missing:
1998      raise ValueError('`optimization_parameters` is missing.')
1999
2000
2001class _OptimizerHandler(object):
2002  """Interface class for handling optimizer specific logic."""
2003
2004  def __init__(self, optimization_parameters):
2005    self._optimization_parameters = optimization_parameters
2006
2007  def get_optimization_parameters(self):
2008    return self._optimization_parameters
2009
2010  def set_optimization_parameters(self, table_descriptor):
2011    raise NotImplementedError()
2012
2013  def get_default_slot_variable_names(self, table):
2014    raise NotImplementedError()
2015
2016  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2017                               table_config, table_variables, config_proto):
2018    raise NotImplementedError()
2019
2020
2021class _AdagradHandler(_OptimizerHandler):
2022  """Handles Adagrad specific logic."""
2023
2024  def set_optimization_parameters(self, table_descriptor):
2025    table_descriptor.optimization_parameters.adagrad.SetInParent()
2026
2027  def get_default_slot_variable_names(self, table):
2028    return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad'))
2029
2030  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2031                               table_config, table_variables, config_proto):
2032    accumulator_initializer = init_ops.constant_initializer(
2033        self._optimization_parameters.initial_accumulator)
2034    accumulator_variables = _create_partitioned_variables(
2035        name=slot_variable_names.accumulator,
2036        num_hosts=num_hosts,
2037        vocabulary_size=table_config.vocabulary_size,
2038        embedding_dimension=table_config.dimension,
2039        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2040        initializer=accumulator_initializer)
2041    slot_variables = AdagradSlotVariable(accumulator_variables)
2042
2043    def load_ops_fn():
2044      """Returns the retrieve ops for AdaGrad embedding tables.
2045
2046      Returns:
2047        A list of ops to load embedding and slot variables from CPU to TPU.
2048      """
2049      config = config_proto
2050      load_op_list = []
2051      for host_id, table_variable, accumulator_variable in zip(
2052          range(num_hosts), table_variables, accumulator_variables):
2053        with ops.colocate_with(table_variable):
2054          load_parameters_op = (
2055              tpu_ops.load_tpu_embedding_adagrad_parameters(
2056                  parameters=table_variable,
2057                  accumulators=accumulator_variable,
2058                  table_name=table,
2059                  num_shards=num_hosts,
2060                  shard_id=host_id,
2061                  config=config))
2062        config = None
2063        load_op_list.append(load_parameters_op)
2064      return load_op_list
2065
2066    def retrieve_ops_fn():
2067      """Returns the retrieve ops for AdaGrad embedding tables.
2068
2069      Returns:
2070        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2071      """
2072      config = config_proto
2073      retrieve_op_list = []
2074      for host_id, table_variable, accumulator_variable in (zip(
2075          range(num_hosts), table_variables, accumulator_variables)):
2076        with ops.colocate_with(table_variable):
2077          retrieved_table, retrieved_accumulator = (
2078              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
2079                  table_name=table,
2080                  num_shards=num_hosts,
2081                  shard_id=host_id,
2082                  config=config))
2083          retrieve_parameters_op = control_flow_ops.group(
2084              state_ops.assign(table_variable, retrieved_table),
2085              state_ops.assign(accumulator_variable, retrieved_accumulator))
2086        config = None
2087        retrieve_op_list.append(retrieve_parameters_op)
2088      return retrieve_op_list
2089
2090    return slot_variables, load_ops_fn, retrieve_ops_fn
2091
2092
2093class _ProximalAdagradHandler(_OptimizerHandler):
2094  """Handles ProximalAdagrad specific logic."""
2095
2096  def set_optimization_parameters(self, table_descriptor):
2097    table_descriptor.optimization_parameters.proximal_adagrad.SetInParent()
2098    table_descriptor.optimization_parameters.proximal_adagrad.l1 = (
2099        self._optimization_parameters.l1_regularization_strength)
2100    table_descriptor.optimization_parameters.proximal_adagrad.l2 = (
2101        self._optimization_parameters.l2_regularization_strength)
2102
2103  def get_default_slot_variable_names(self, table):
2104    return ProximalAdagradSlotVariableName('{}/{}'.format(
2105        table, 'ProximalAdagrad'))
2106
2107  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2108                               table_config, table_variables, config_proto):
2109    accumulator_initializer = init_ops.constant_initializer(
2110        self._optimization_parameters.initial_accumulator)
2111    accumulator_variables = _create_partitioned_variables(
2112        name=slot_variable_names.accumulator,
2113        num_hosts=num_hosts,
2114        vocabulary_size=table_config.vocabulary_size,
2115        embedding_dimension=table_config.dimension,
2116        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2117        initializer=accumulator_initializer)
2118    slot_variables = ProximalAdagradSlotVariable(accumulator_variables)
2119
2120    def load_ops_fn():
2121      """Returns the retrieve ops for Proximal AdaGrad embedding tables.
2122
2123      Returns:
2124        A list of ops to load embedding and slot variables from CPU to TPU.
2125      """
2126      config = config_proto
2127      load_op_list = []
2128      for host_id, table_variable, accumulator_variable in zip(
2129          range(num_hosts), table_variables, accumulator_variables):
2130        with ops.colocate_with(table_variable):
2131          load_parameters_op = (
2132              tpu_ops.load_tpu_embedding_proximal_adagrad_parameters(
2133                  parameters=table_variable,
2134                  accumulators=accumulator_variable,
2135                  table_name=table,
2136                  num_shards=num_hosts,
2137                  shard_id=host_id,
2138                  config=config))
2139        config = None
2140        load_op_list.append(load_parameters_op)
2141      return load_op_list
2142
2143    def retrieve_ops_fn():
2144      """Returns the retrieve ops for Proximal AdaGrad embedding tables.
2145
2146      Returns:
2147        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2148      """
2149      config = config_proto
2150      retrieve_op_list = []
2151      for host_id, table_variable, accumulator_variable in (zip(
2152          range(num_hosts), table_variables, accumulator_variables)):
2153        with ops.colocate_with(table_variable):
2154          retrieved_table, retrieved_accumulator = (
2155              tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters(
2156                  table_name=table,
2157                  num_shards=num_hosts,
2158                  shard_id=host_id,
2159                  config=config))
2160          retrieve_parameters_op = control_flow_ops.group(
2161              state_ops.assign(table_variable, retrieved_table),
2162              state_ops.assign(accumulator_variable, retrieved_accumulator))
2163        config = None
2164        retrieve_op_list.append(retrieve_parameters_op)
2165      return retrieve_op_list
2166
2167    return slot_variables, load_ops_fn, retrieve_ops_fn
2168
2169
2170class _AdamHandler(_OptimizerHandler):
2171  """Handles Adam specific logic."""
2172
2173  def set_optimization_parameters(self, table_descriptor):
2174    table_descriptor.optimization_parameters.adam.beta1 = (
2175        self._optimization_parameters.beta1)
2176    table_descriptor.optimization_parameters.adam.beta2 = (
2177        self._optimization_parameters.beta2)
2178    table_descriptor.optimization_parameters.adam.epsilon = (
2179        self._optimization_parameters.epsilon)
2180    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
2181        not self._optimization_parameters.lazy_adam)
2182    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
2183        self._optimization_parameters.sum_inside_sqrt)
2184
2185  def get_default_slot_variable_names(self, table):
2186    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
2187                                 '{}/{}/v'.format(table, 'Adam'))
2188
2189  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2190                               table_config, table_variables, config_proto):
2191    m_initializer = init_ops.zeros_initializer()
2192    m_variables = _create_partitioned_variables(
2193        name=slot_variable_names.m,
2194        num_hosts=num_hosts,
2195        vocabulary_size=table_config.vocabulary_size,
2196        embedding_dimension=table_config.dimension,
2197        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2198        initializer=m_initializer)
2199    v_initializer = init_ops.zeros_initializer()
2200    v_variables = _create_partitioned_variables(
2201        name=slot_variable_names.v,
2202        num_hosts=num_hosts,
2203        vocabulary_size=table_config.vocabulary_size,
2204        embedding_dimension=table_config.dimension,
2205        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2206        initializer=v_initializer)
2207    slot_variables = AdamSlotVariables(m_variables, v_variables)
2208
2209    def load_ops_fn():
2210      """Returns the retrieve ops for AdaGrad embedding tables.
2211
2212      Returns:
2213        A list of ops to load embedding and slot variables from CPU to TPU.
2214      """
2215      load_op_list = []
2216      config = config_proto
2217      for host_id, table_variable, m_variable, v_variable in (zip(
2218          range(num_hosts), table_variables, m_variables, v_variables)):
2219        with ops.colocate_with(table_variable):
2220          load_parameters_op = (
2221              tpu_ops.load_tpu_embedding_adam_parameters(
2222                  parameters=table_variable,
2223                  momenta=m_variable,
2224                  velocities=v_variable,
2225                  table_name=table,
2226                  num_shards=num_hosts,
2227                  shard_id=host_id,
2228                  config=config))
2229        # Set config to None to enforce that config is only loaded to the first
2230        # table.
2231        config = None
2232        load_op_list.append(load_parameters_op)
2233      return load_op_list
2234
2235    def retrieve_ops_fn():
2236      """Returns the retrieve ops for Adam embedding tables.
2237
2238      Returns:
2239        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2240      """
2241      retrieve_op_list = []
2242      config = config_proto
2243      for host_id, table_variable, m_variable, v_variable in (zip(
2244          range(num_hosts), table_variables, m_variables, v_variables)):
2245        with ops.colocate_with(table_variable):
2246          retrieved_table, retrieved_m, retrieved_v = (
2247              tpu_ops.retrieve_tpu_embedding_adam_parameters(
2248                  table_name=table,
2249                  num_shards=num_hosts,
2250                  shard_id=host_id,
2251                  config=config))
2252          retrieve_parameters_op = control_flow_ops.group(
2253              state_ops.assign(table_variable, retrieved_table),
2254              state_ops.assign(m_variable, retrieved_m),
2255              state_ops.assign(v_variable, retrieved_v))
2256        config = None
2257        retrieve_op_list.append(retrieve_parameters_op)
2258      return retrieve_op_list
2259
2260    return slot_variables, load_ops_fn, retrieve_ops_fn
2261
2262
2263class _FtrlHandler(_OptimizerHandler):
2264  """Handles Ftrl specific logic."""
2265
2266  def set_optimization_parameters(self, table_descriptor):
2267    table_descriptor.optimization_parameters.ftrl.lr_power = (
2268        self._optimization_parameters.learning_rate_power)
2269    table_descriptor.optimization_parameters.ftrl.l1 = (
2270        self._optimization_parameters.l1_regularization_strength)
2271    table_descriptor.optimization_parameters.ftrl.l2 = (
2272        self._optimization_parameters.l2_regularization_strength)
2273    table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = (
2274        self._optimization_parameters.multiply_linear_by_learning_rate)
2275    table_descriptor.optimization_parameters.ftrl.beta = (
2276        self._optimization_parameters.beta)
2277    table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = (
2278        self._optimization_parameters.allow_zero_accumulator)
2279
2280  def get_default_slot_variable_names(self, table):
2281    # These match the default slot variable names created by
2282    # tf.train.FtrlOptimizer.
2283    return FtrlSlotVariableName(
2284        '{}/{}'.format(table, 'Ftrl'),  # accumulator
2285        '{}/{}'.format(table, 'Ftrl_1'))  # linear
2286
2287  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2288                               table_config, table_variables, config_proto):
2289    accumulator_initializer = init_ops.constant_initializer(
2290        self._optimization_parameters.initial_accumulator_value)
2291    accumulator_variables = _create_partitioned_variables(
2292        name=slot_variable_names.accumulator,
2293        num_hosts=num_hosts,
2294        vocabulary_size=table_config.vocabulary_size,
2295        embedding_dimension=table_config.dimension,
2296        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2297        initializer=accumulator_initializer)
2298    linear_initializer = init_ops.constant_initializer(
2299        self._optimization_parameters.initial_linear_value)
2300    linear_variables = _create_partitioned_variables(
2301        name=slot_variable_names.linear,
2302        num_hosts=num_hosts,
2303        vocabulary_size=table_config.vocabulary_size,
2304        embedding_dimension=table_config.dimension,
2305        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2306        initializer=linear_initializer)
2307    slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)
2308
2309    def load_ops_fn():
2310      """Returns the retrieve ops for Ftrl embedding tables.
2311
2312      Returns:
2313        A list of ops to load embedding and slot variables from CPU to TPU.
2314      """
2315      config = config_proto
2316      load_op_list = []
2317      for host_id, table_variable, accumulator_variable, linear_variable in zip(
2318          range(num_hosts), table_variables, accumulator_variables,
2319          linear_variables):
2320        with ops.colocate_with(table_variable):
2321          load_parameters_op = (
2322              tpu_ops.load_tpu_embedding_ftrl_parameters(
2323                  parameters=table_variable,
2324                  accumulators=accumulator_variable,
2325                  linears=linear_variable,
2326                  table_name=table,
2327                  num_shards=num_hosts,
2328                  shard_id=host_id,
2329                  config=config))
2330        config = None
2331        load_op_list.append(load_parameters_op)
2332      return load_op_list
2333
2334    def retrieve_ops_fn():
2335      """Returns the retrieve ops for Ftrl embedding tables.
2336
2337      Returns:
2338        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2339      """
2340      config = config_proto
2341      retrieve_op_list = []
2342      for host_id, table_variable, accumulator_variable, linear_variable in zip(
2343          range(num_hosts), table_variables, accumulator_variables,
2344          linear_variables):
2345        with ops.colocate_with(table_variable):
2346          retrieved_table, retrieved_accumulator, retrieved_linear = (
2347              tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
2348                  table_name=table,
2349                  num_shards=num_hosts,
2350                  shard_id=host_id,
2351                  config=config))
2352          retrieve_parameters_op = control_flow_ops.group(
2353              state_ops.assign(table_variable, retrieved_table),
2354              state_ops.assign(accumulator_variable, retrieved_accumulator),
2355              state_ops.assign(linear_variable, retrieved_linear))
2356        config = None
2357        retrieve_op_list.append(retrieve_parameters_op)
2358      return retrieve_op_list
2359
2360    return slot_variables, load_ops_fn, retrieve_ops_fn
2361
2362
2363class _ProximalYogiHandler(_OptimizerHandler):
2364  """Handles Proximal Yogi specific logic."""
2365
2366  def set_optimization_parameters(self, table_descriptor):
2367    table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
2368    table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
2369        self._optimization_parameters.beta1)
2370    table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
2371        self._optimization_parameters.beta2)
2372    table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
2373        self._optimization_parameters.epsilon)
2374    table_descriptor.optimization_parameters.proximal_yogi.l1 = (
2375        self._optimization_parameters.l1_regularization_strength)
2376    table_descriptor.optimization_parameters.proximal_yogi.l2 = (
2377        self._optimization_parameters.l2_regularization_strength)
2378
2379  def get_default_slot_variable_names(self, table):
2380    return ProximalYogiSlotVariableNames(
2381        '{}/{}'.format(table, 'ProximalYogi'),  # v
2382        '{}/{}_1'.format(table, 'ProximalYogi'))  # m
2383
2384  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2385                               table_config, table_variables, config_proto):
2386    v_initializer = init_ops.constant_initializer(
2387        self._optimization_parameters.initial_accumulator_value)
2388    v_variables = _create_partitioned_variables(
2389        name=slot_variable_names.v,
2390        num_hosts=num_hosts,
2391        vocabulary_size=table_config.vocabulary_size,
2392        embedding_dimension=table_config.dimension,
2393        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2394        initializer=v_initializer)
2395    m_initializer = init_ops.zeros_initializer()
2396    m_variables = _create_partitioned_variables(
2397        name=slot_variable_names.m,
2398        num_hosts=num_hosts,
2399        vocabulary_size=table_config.vocabulary_size,
2400        embedding_dimension=table_config.dimension,
2401        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2402        initializer=m_initializer)
2403    slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)
2404
2405    def load_ops_fn():
2406      """Returns the load ops for Proximal Yogi embedding tables.
2407
2408      Returns:
2409        A list of ops to load embedding and slot variables from CPU to TPU.
2410      """
2411      load_op_list = []
2412      config = config_proto
2413      for host_id, table_variable, v_variable, m_variable in (zip(
2414          range(num_hosts), table_variables, v_variables, m_variables)):
2415        with ops.colocate_with(table_variable):
2416          load_parameters_op = (
2417              tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
2418                  parameters=table_variable,
2419                  v=v_variable,
2420                  m=m_variable,
2421                  table_name=table,
2422                  num_shards=num_hosts,
2423                  shard_id=host_id,
2424                  config=config))
2425        # Set config to None to enforce that config is only loaded to the first
2426        # table.
2427        config = None
2428        load_op_list.append(load_parameters_op)
2429      return load_op_list
2430
2431    def retrieve_ops_fn():
2432      """Returns the retrieve ops for Proximal Yogi embedding tables.
2433
2434      Returns:
2435        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2436      """
2437      retrieve_op_list = []
2438      config = config_proto
2439      for host_id, table_variable, v_variable, m_variable in (zip(
2440          range(num_hosts), table_variables, v_variables, m_variables)):
2441        with ops.colocate_with(table_variable):
2442          retrieved_table, retrieved_v, retrieved_m = (
2443              tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
2444                  table_name=table,
2445                  num_shards=num_hosts,
2446                  shard_id=host_id,
2447                  config=config))
2448          retrieve_parameters_op = control_flow_ops.group(
2449              state_ops.assign(table_variable, retrieved_table),
2450              state_ops.assign(v_variable, retrieved_v),
2451              state_ops.assign(m_variable, retrieved_m))
2452        config = None
2453        retrieve_op_list.append(retrieve_parameters_op)
2454      return retrieve_op_list
2455
2456    return slot_variables, load_ops_fn, retrieve_ops_fn
2457
2458
2459class _MomentumHandler(_OptimizerHandler):
2460  """Handles Momentum specific logic."""
2461
2462  def set_optimization_parameters(self, table_descriptor):
2463    (table_descriptor.optimization_parameters.momentum.SetInParent())
2464    table_descriptor.optimization_parameters.momentum.momentum = (
2465        self._optimization_parameters.momentum)
2466    table_descriptor.optimization_parameters.momentum.use_nesterov = (
2467        self._optimization_parameters.use_nesterov)
2468
2469  def get_default_slot_variable_names(self, table):
2470    return MomentumSlotVariableName('{}/{}'.format(table, 'Momentum'))
2471
2472  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2473                               table_config, table_variables, config_proto):
2474
2475    momenta_initializer = init_ops.zeros_initializer()
2476    momenta_variables = _create_partitioned_variables(
2477        name=slot_variable_names.momenta,
2478        num_hosts=num_hosts,
2479        vocabulary_size=table_config.vocabulary_size,
2480        embedding_dimension=table_config.dimension,
2481        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2482        initializer=momenta_initializer)
2483    slot_variables = MomentumSlotVariable(momenta_variables)
2484
2485    def load_ops_fn():
2486      """Returns the retrieve ops for Momentum embedding tables.
2487
2488      Returns:
2489        A list of ops to load embedding and slot variables from CPU to TPU.
2490      """
2491      load_op_list = []
2492      config = config_proto
2493      for host_id, table_variable, momenta_variable in (zip(
2494          range(num_hosts), table_variables, momenta_variables)):
2495        with ops.colocate_with(table_variable):
2496          load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters(
2497              parameters=table_variable,
2498              momenta=momenta_variable,
2499              table_name=table,
2500              num_shards=num_hosts,
2501              shard_id=host_id,
2502              config=config,
2503          )
2504        config = None
2505        load_op_list.append(load_parameters_op)
2506      return load_op_list
2507
2508    def retrieve_ops_fn():
2509      """Returns the retrieve ops for Momentum embedding tables.
2510
2511      Returns:
2512        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2513      """
2514      retrieve_op_list = []
2515      config = config_proto
2516      for host_id, table_variable, momenta_variable in (zip(
2517          range(num_hosts), table_variables, momenta_variables)):
2518        with ops.colocate_with(table_variable):
2519          retrieved_table, retrieved_momenta = (
2520              tpu_ops.retrieve_tpu_embedding_momentum_parameters(
2521                  table_name=table,
2522                  num_shards=num_hosts,
2523                  shard_id=host_id,
2524                  config=config,
2525              ))
2526          retrieve_parameters_op = control_flow_ops.group(
2527              state_ops.assign(table_variable, retrieved_table),
2528              state_ops.assign(momenta_variable, retrieved_momenta))
2529        config = None
2530        retrieve_op_list.append(retrieve_parameters_op)
2531      return retrieve_op_list
2532
2533    return slot_variables, load_ops_fn, retrieve_ops_fn
2534
2535
2536class _RMSPropHandler(_OptimizerHandler):
2537  """Handles RMS prop specific logic."""
2538
2539  def set_optimization_parameters(self, table_descriptor):
2540    (table_descriptor.optimization_parameters.rms_prop.SetInParent())
2541    table_descriptor.optimization_parameters.rms_prop.rho = (
2542        self._optimization_parameters.rho)
2543    table_descriptor.optimization_parameters.rms_prop.epsilon = (
2544        self._optimization_parameters.epsilon)
2545    table_descriptor.optimization_parameters.rms_prop.momentum = (
2546        self._optimization_parameters.momentum)
2547
2548  def get_default_slot_variable_names(self, table):
2549    return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'),
2550                                    '{}/{}/mom'.format(table, 'RMSProp'))
2551
2552  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2553                               table_config, table_variables, config_proto):
2554
2555    ms_variables = _create_partitioned_variables(
2556        name=slot_variable_names.ms,
2557        num_hosts=num_hosts,
2558        vocabulary_size=table_config.vocabulary_size,
2559        embedding_dimension=table_config.dimension,
2560        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2561        initializer=init_ops.zeros_initializer(),
2562    )
2563    mom_variables = _create_partitioned_variables(
2564        name=slot_variable_names.mom,
2565        num_hosts=num_hosts,
2566        vocabulary_size=table_config.vocabulary_size,
2567        embedding_dimension=table_config.dimension,
2568        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2569        initializer=init_ops.zeros_initializer(),
2570    )
2571    slot_variables = RMSPropSlotVariables(ms_variables, mom_variables)
2572
2573    def load_ops_fn():
2574      """Returns the retrieve ops for RMS Prop embedding tables.
2575
2576      Returns:
2577        A list of ops to load embedding and slot variables from CPU to TPU.
2578      """
2579      load_op_list = []
2580      config = config_proto
2581      for host_id, table_variable, ms_variable, mom_variable in (zip(
2582          range(num_hosts), table_variables, ms_variables, mom_variables)):
2583        with ops.colocate_with(table_variable):
2584          load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters(
2585              parameters=table_variable,
2586              ms=ms_variable,
2587              mom=mom_variable,
2588              table_name=table,
2589              num_shards=num_hosts,
2590              shard_id=host_id,
2591              config=config,
2592          )
2593        config = None
2594        load_op_list.append(load_parameters_op)
2595      return load_op_list
2596
2597    def retrieve_ops_fn():
2598      """Returns the retrieve ops for RMS Prop embedding tables.
2599
2600      Returns:
2601        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2602      """
2603      retrieve_op_list = []
2604      config = config_proto
2605      for host_id, table_variable, ms_variable, mom_variable in (zip(
2606          range(num_hosts), table_variables, ms_variables, mom_variables)):
2607        with ops.colocate_with(table_variable):
2608          retrieved_table, retrieved_ms, retrieved_mom = (
2609              tpu_ops.retrieve_tpu_embedding_rms_prop_parameters(
2610                  table_name=table,
2611                  num_shards=num_hosts,
2612                  shard_id=host_id,
2613                  config=config,
2614              ))
2615          retrieve_parameters_op = control_flow_ops.group(
2616              state_ops.assign(table_variable, retrieved_table),
2617              state_ops.assign(ms_variable, retrieved_ms),
2618              state_ops.assign(mom_variable, retrieved_mom))
2619        config = None
2620        retrieve_op_list.append(retrieve_parameters_op)
2621      return retrieve_op_list
2622
2623    return slot_variables, load_ops_fn, retrieve_ops_fn
2624
2625
2626class _FrequencyEstimatorHandler(_OptimizerHandler):
2627  """Handles frequency estimator specific logic."""
2628
2629  def set_optimization_parameters(self, table_descriptor):
2630    table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
2631    freq = table_descriptor.optimization_parameters.frequency_estimator
2632    freq.tau = self._optimization_parameters.tau
2633    freq.max_delta = self._optimization_parameters.max_delta
2634    freq.outlier_threshold = self._optimization_parameters.outlier_threshold
2635    freq.weight_exponent = self._optimization_parameters.weight_exponent
2636
2637  def get_default_slot_variable_names(self, table):
2638    return FrequencyEstimatorSlotVariableName(
2639        '{}/FrequencyEstimator'.format(table))
2640
2641  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2642                               table_config, table_variables, config_proto):
2643    if table_config.dimension != 1:
2644      raise ValueError('FrequencyEstimator tables should only have a dimension '
2645                       'of 1. Received dimension {}'.format(
2646                           table_config.dimension))
2647
2648    last_hit_step_variables = _create_partitioned_variables(
2649        name=slot_variable_names.last_hit_step,
2650        num_hosts=num_hosts,
2651        vocabulary_size=table_config.vocabulary_size,
2652        embedding_dimension=table_config.dimension,
2653        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2654        initializer=init_ops.zeros_initializer(),
2655    )
2656    slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)
2657
2658    def load_ops_fn():
2659      """Returns the retrieve ops for Frequency Estimator embedding tables.
2660
2661      Returns:
2662        A list of ops to load embedding and slot variables from CPU to TPU.
2663      """
2664      load_op_list = []
2665      config = config_proto
2666      for host_id, table_variable, last_hit_step_variable in (zip(
2667          range(num_hosts), table_variables, last_hit_step_variables)):
2668        with ops.colocate_with(table_variable):
2669          load_parameters_op = (
2670              tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
2671                  parameters=table_variable,
2672                  last_hit_step=last_hit_step_variable,
2673                  table_name=table,
2674                  num_shards=num_hosts,
2675                  shard_id=host_id,
2676                  config=config))
2677        config = None
2678        load_op_list.append(load_parameters_op)
2679      return load_op_list
2680
2681    def retrieve_ops_fn():
2682      """Returns the retrieve ops for Frequency Estimator embedding tables.
2683
2684      Returns:
2685        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2686      """
2687      retrieve_op_list = []
2688      config = config_proto
2689      for host_id, table_variable, last_hit_step_variable in (zip(
2690          range(num_hosts), table_variables, last_hit_step_variables)):
2691        with ops.colocate_with(table_variable):
2692          retrieved_table, retrieved_last_hit_step = (
2693              tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
2694                  table_name=table,
2695                  num_shards=num_hosts,
2696                  shard_id=host_id,
2697                  config=config,
2698              ))
2699          retrieve_parameters_op = control_flow_ops.group(
2700              state_ops.assign(table_variable, retrieved_table),
2701              state_ops.assign(last_hit_step_variable, retrieved_last_hit_step))
2702        config = None
2703        retrieve_op_list.append(retrieve_parameters_op)
2704      return retrieve_op_list
2705
2706    return slot_variables, load_ops_fn, retrieve_ops_fn
2707
2708
2709class _StochasticGradientDescentHandler(_OptimizerHandler):
2710  """Handles stochastic gradient descent specific logic."""
2711
2712  def set_optimization_parameters(self, table_descriptor):
2713    (table_descriptor.optimization_parameters.stochastic_gradient_descent
2714     .SetInParent())
2715
2716  def get_default_slot_variable_names(self, table):
2717    return None
2718
2719  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2720                               table_config, table_variables, config_proto):
2721    del table_config
2722
2723    def load_ops_fn():
2724      """Returns the retrieve ops for AdaGrad embedding tables.
2725
2726      Returns:
2727        A list of ops to load embedding and slot variables from CPU to TPU.
2728      """
2729      load_op_list = []
2730      config = config_proto
2731      for host_id, table_variable in enumerate(table_variables):
2732        with ops.colocate_with(table_variable):
2733          load_parameters_op = (
2734              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
2735                  parameters=table_variable,
2736                  table_name=table,
2737                  num_shards=num_hosts,
2738                  shard_id=host_id,
2739                  config=config))
2740        config = None
2741        load_op_list.append(load_parameters_op)
2742      return load_op_list
2743
2744    def retrieve_ops_fn():
2745      """Returns the retrieve ops for SGD embedding tables.
2746
2747      Returns:
2748        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2749      """
2750      retrieve_op_list = []
2751      config = config_proto
2752      for host_id, table_variable in enumerate(table_variables):
2753        with ops.colocate_with(table_variable):
2754          retrieved_table = (
2755              tpu_ops
2756              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
2757                  table_name=table,
2758                  num_shards=num_hosts,
2759                  shard_id=host_id,
2760                  config=config))
2761          retrieve_parameters_op = control_flow_ops.group(
2762              state_ops.assign(table_variable, retrieved_table))
2763        config = None
2764        retrieve_op_list.append(retrieve_parameters_op)
2765      return retrieve_op_list
2766
2767    return None, load_ops_fn, retrieve_ops_fn
2768
2769
2770def _get_optimization_handler(optimization_parameters):
2771  """Gets the optimization handler given the parameter type."""
2772  if isinstance(optimization_parameters, AdagradParameters):
2773    return _AdagradHandler(optimization_parameters)
2774  elif isinstance(optimization_parameters, ProximalAdagradParameters):
2775    return _ProximalAdagradHandler(optimization_parameters)
2776  elif isinstance(optimization_parameters, AdamParameters):
2777    return _AdamHandler(optimization_parameters)
2778  elif isinstance(optimization_parameters, FtrlParameters):
2779    return _FtrlHandler(optimization_parameters)
2780  elif isinstance(optimization_parameters, ProximalYogiParameters):
2781    return _ProximalYogiHandler(optimization_parameters)
2782  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
2783    return _StochasticGradientDescentHandler(optimization_parameters)
2784  elif isinstance(optimization_parameters, MomentumParameters):
2785    return _MomentumHandler(optimization_parameters)
2786  elif isinstance(optimization_parameters, RMSPropParameters):
2787    return _RMSPropHandler(optimization_parameters)
2788  elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
2789    return _FrequencyEstimatorHandler(optimization_parameters)
2790  return NotImplementedError()
2791
2792
2793def _create_ordered_dict(d):
2794  """Create an OrderedDict from Dict."""
2795  return collections.OrderedDict((k, d[k]) for k in sorted(d))
2796
2797
2798def _create_combiners(table_to_config_dict, table_to_features_dict):
2799  """Create a per feature list of combiners, ordered by table."""
2800  combiners = []
2801  for table in table_to_config_dict:
2802    combiner = table_to_config_dict[table].combiner or 'sum'
2803    combiners.extend([combiner] * len(table_to_features_dict[table]))
2804  return combiners
2805
2806
2807def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
2808  """Create mapping from table to a list of its features."""
2809  table_to_features_dict_tmp = {}
2810  table_to_num_features_dict_tmp = {}
2811  for feature, feature_config in six.iteritems(feature_to_config_dict):
2812    if feature_config.table_id in table_to_features_dict_tmp:
2813      table_to_features_dict_tmp[feature_config.table_id].append(feature)
2814    else:
2815      table_to_features_dict_tmp[feature_config.table_id] = [feature]
2816      table_to_num_features_dict_tmp[feature_config.table_id] = 0
2817    if feature_config.max_sequence_length == 0:
2818      table_to_num_features_dict_tmp[feature_config.table_id] = (
2819          table_to_num_features_dict_tmp[feature_config.table_id] + 1)
2820    else:
2821      table_to_num_features_dict_tmp[feature_config.table_id] = (
2822          table_to_num_features_dict_tmp[feature_config.table_id] +
2823          feature_config.max_sequence_length)
2824
2825  table_to_features_dict = collections.OrderedDict()
2826  table_to_num_features_dict = collections.OrderedDict()
2827  for table in sorted(table_to_features_dict_tmp):
2828    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
2829    table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
2830  return table_to_features_dict, table_to_num_features_dict
2831
2832
2833def _create_device_fn(hosts):
2834  """Create device_fn() to use with _create_partitioned_variables()."""
2835
2836  def device_fn(op):
2837    """Returns the `device` for `op`."""
2838    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
2839    dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
2840    if not part_match and not dummy_match:
2841      raise RuntimeError(
2842          'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
2843              op.name))
2844
2845    if part_match:
2846      idx = int(part_match.group(1))
2847    else:
2848      idx = int(dummy_match.group(1))  # pytype: disable=attribute-error
2849
2850    device = hosts[idx]
2851    logging.debug('assigning {} to {}.', op, device)
2852    return device
2853
2854  return device_fn
2855
2856
2857def _create_partitioned_variables(name,
2858                                  num_hosts,
2859                                  vocabulary_size,
2860                                  embedding_dimension,
2861                                  initializer,
2862                                  collections=None):  # pylint: disable=redefined-outer-name
2863  """Creates PartitionedVariables based on `num_hosts` for `table`."""
2864
2865  num_slices = min(vocabulary_size, num_hosts)
2866
2867  var_list = list(
2868      variable_scope.get_variable(
2869          name,
2870          shape=(vocabulary_size, embedding_dimension),
2871          partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
2872          dtype=dtypes.float32,
2873          initializer=initializer,
2874          collections=collections,
2875          trainable=False))
2876
2877  if vocabulary_size >= num_hosts:
2878    return var_list
2879
2880  # For padded part, define the dummy variable to be loaded into TPU system.
2881  for idx in range(num_hosts - vocabulary_size):
2882    var_list.append(
2883        variable_scope.get_variable(
2884            'dummy_{}_{}'.format(vocabulary_size + idx, name),
2885            shape=(1, embedding_dimension),
2886            dtype=dtypes.float32,
2887            initializer=initializer,
2888            collections=[ops.GraphKeys.LOCAL_VARIABLES],
2889            trainable=False))
2890
2891  return var_list
2892