1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Companion classes for mid level API for TPU Embeddings in TF2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from __future__ import unicode_literals
21
22import abc
23import math
24import typing
25from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
26
27from absl import logging
28import six
29
30from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
31from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
32from tensorflow.python.distribute import sharded_variable
33from tensorflow.python.framework import ops
34from tensorflow.python.ops import init_ops_v2
35from tensorflow.python.ops import variables as tf_variables
36from tensorflow.python.tpu.ops import tpu_ops
37from tensorflow.python.types import core
38from tensorflow.python.util.tf_export import tf_export
39
40
41TableVariable = TypeVar("TableVariable", sharded_variable.ShardedVariable,
42                        tf_variables.Variable)
43SlotVarCreationFnType = Callable[
44    [TableVariable, List[Text], List[init_ops_v2.Initializer]],
45    Dict[Text, TableVariable]]
46ClipValueType = Union[Tuple[float, float], float]
47
48
49@six.add_metaclass(abc.ABCMeta)
50class _Optimizer(object):
51  """Base class for all optimizers, with common parameters."""
52
53  def __init__(
54      self,
55      learning_rate: Union[float, Callable[[], float]],
56      use_gradient_accumulation: bool,
57      clip_weight_min: Optional[float],
58      clip_weight_max: Optional[float],
59      weight_decay_factor: Optional[float],
60      multiply_weight_decay_factor_by_learning_rate: bool,
61      clipvalue: Optional[ClipValueType] = None,
62      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None):
63    self.learning_rate = learning_rate
64    self.use_gradient_accumulation = use_gradient_accumulation
65    self.clip_weight_min = clip_weight_min
66    self.clip_weight_max = clip_weight_max
67    if not use_gradient_accumulation and clipvalue is not None:
68      raise ValueError("Received non-None gradient clipping limit {} but "
69                       "use_gradient_accumulation is not set to True.".format(
70                           clipvalue))
71    if clipvalue is None:
72      clipvalue = (None, None)
73    elif not isinstance(clipvalue, tuple):
74      clipvalue = (-1. * clipvalue, clipvalue)
75    self.clip_gradient_min, self.clip_gradient_max = clipvalue
76
77    self.weight_decay_factor = weight_decay_factor
78    self.multiply_weight_decay_factor_by_learning_rate = (
79        multiply_weight_decay_factor_by_learning_rate)
80
81    if (slot_variable_creation_fn is not None and
82        not callable(slot_variable_creation_fn)):
83      raise ValueError("slot_variable_creation_fn must be either None or a "
84                       "callable.")
85    self.slot_variable_creation_fn = slot_variable_creation_fn
86
87  @abc.abstractmethod
88  def _slot_names(self) -> List[Text]:
89    """Returns the name of all the slot variables.
90
91    This does not include the 'parameters' variable and these names must match
92    the names of the slots variables as used in the corresponding
93    `tpu_ops.load_tpu_embedding_*` ops.
94    """
95    raise NotImplementedError
96
97  @abc.abstractmethod
98  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
99    """Returns initializers for slot variables.
100
101    This returns a parallel list to self._slot_names().
102    """
103    raise NotImplementedError
104
105  def _set_optimization_parameters(
106      self, parameters: optimization_parameters_pb2.OptimizationParameters):
107    """Sets the optimizer fields in the OptimizationParameters."""
108    if self.use_gradient_accumulation:
109      parameters.gradient_accumulation_status = (
110          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
111    else:
112      parameters.gradient_accumulation_status = (
113          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
114
115    if self.clip_weight_min is not None:
116      parameters.clipping_limits.lower.value = self.clip_weight_min
117
118    if self.clip_weight_max is not None:
119      parameters.clipping_limits.upper.value = self.clip_weight_max
120
121    if self.clip_gradient_min is not None:
122      parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min
123
124    if self.clip_gradient_max is not None:
125      parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max
126
127    if self.weight_decay_factor:
128      parameters.weight_decay_factor = self.weight_decay_factor
129      if self.multiply_weight_decay_factor_by_learning_rate:
130        parameters.multiply_weight_decay_factor_by_learning_rate = True
131
132  @abc.abstractmethod
133  def _load(self) -> Callable[..., ops.Operation]:
134    """Returns the load function for the optimizer."""
135    raise NotImplementedError
136
137  @abc.abstractmethod
138  def _retrieve(self) -> Callable[..., core.Tensor]:
139    """Returns the retrieve function for the optimizer."""
140    raise NotImplementedError
141
142  def _create_slots(
143      self, table: "TableConfig",
144      variable_creator: Callable[[Text, init_ops_v2.Initializer],
145                                 tf_variables.Variable]
146  ) -> Dict[Text, tf_variables.Variable]:
147    """Creates slot variables for table.
148
149    Args:
150      table: The table variable to create slots for.
151      variable_creator: A function which creates variables. Takes parameters
152        'name', 'initializer'.
153
154    Returns:
155      A dict of variables, keyed by self._slot_names().
156    """
157    if self.slot_variable_creation_fn is not None:
158      return self.slot_variable_creation_fn(table, self._slot_names(),
159                                            self._slot_initializers())
160    else:
161      slots = {}
162      for slot, initializer in zip(self._slot_names(),
163                                   self._slot_initializers()):
164        slots[slot] = variable_creator(slot, initializer)
165      return slots
166
167
168@tf_export("tpu.experimental.embedding.SGD")
169class SGD(_Optimizer):
170  """Optimization parameters for stochastic gradient descent for TPU embeddings.
171
172  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
173  argument to set the global optimizer and its parameters:
174
175  ```
176  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
177      ...
178      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
179  ```
180
181  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
182  optimizer parameter to set a table specific optimizer. This will override the
183  optimizer and parameters for global embedding optimizer defined above:
184
185  ```
186  table_one = tf.tpu.experimental.embedding.TableConfig(
187      vocabulary_size=...,
188      dim=...,
189      optimizer=tf.tpu.experimental.embedding.SGD(0.2))
190  table_two = tf.tpu.experimental.embedding.TableConfig(
191      vocabulary_size=...,
192      dim=...)
193
194  feature_config = (
195      tf.tpu.experimental.embedding.FeatureConfig(
196          table=table_one),
197      tf.tpu.experimental.embedding.FeatureConfig(
198          table=table_two))
199
200  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
201      feature_config=feature_config,
202      batch_size=...
203      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
204  ```
205
206  In the above example, the first feature will be looked up in a table that has
207  a learning rate of 0.2 while the second feature will be looked up in a table
208  that has a learning rate of 0.1.
209
210  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
211  complete description of these parameters and their impacts on the optimizer
212  algorithm.
213  """
214
215  def __init__(self,
216               learning_rate: Union[float, Callable[[], float]] = 0.01,
217               clip_weight_min: Optional[float] = None,
218               clip_weight_max: Optional[float] = None,
219               weight_decay_factor: Optional[float] = None,
220               multiply_weight_decay_factor_by_learning_rate: bool = None,
221               clipvalue: Optional[ClipValueType] = None):
222    """Optimization parameters for stochastic gradient descent.
223
224    Args:
225      learning_rate: The learning rate. It should be a floating point value or a
226        callable taking no arguments for a dynamic learning rate.
227      clip_weight_min: the minimum value to clip by; None means -infinity.
228      clip_weight_max: the maximum value to clip by; None means +infinity.
229      weight_decay_factor: amount of weight decay to apply; None means that the
230        weights are not decayed. Weights are decayed by multiplying the weight
231        by this factor each step.
232      multiply_weight_decay_factor_by_learning_rate: if true,
233        `weight_decay_factor` is multiplied by the current learning rate.
234      clipvalue: Controls clipping of the gradient. Set to either a single
235        positive scalar value to get clipping or a tiple of scalar values (min,
236        max) to set a separate maximum or minimum. If one of the two entries is
237        None, then there will be no clipping that direction. Note if this is
238        set, you may see a decrease in performance as  gradient accumulation
239        will be enabled (it is normally off for SGD as it has no affect on
240        accuracy). See
241        'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for more
242        information on gradient accumulation and its impact on tpu embeddings.
243    """
244    use_gradient_accumulation = clipvalue is not None
245
246    super(SGD, self).__init__(
247        learning_rate, use_gradient_accumulation, clip_weight_min,
248        clip_weight_max, weight_decay_factor,
249        multiply_weight_decay_factor_by_learning_rate, clipvalue)
250
251  def _slot_names(self) -> List[Text]:
252    return []
253
254  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
255    return []
256
257  def _set_optimization_parameters(
258      self, parameters: optimization_parameters_pb2.OptimizationParameters):
259    super(SGD, self)._set_optimization_parameters(parameters)
260    parameters.stochastic_gradient_descent.SetInParent()
261
262  def _load(self) -> Callable[..., ops.Operation]:
263    return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters
264
265  def _retrieve(self) -> Callable[..., core.Tensor]:
266    return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters
267
268
269@tf_export("tpu.experimental.embedding.Adagrad")
270class Adagrad(_Optimizer):
271  """Optimization parameters for Adagrad with TPU embeddings.
272
273  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
274  argument to set the global optimizer and its parameters:
275
276  ```python
277  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
278      ...
279      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
280  ```
281
282  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
283  optimizer parameter to set a table specific optimizer. This will override the
284  optimizer and parameters for global embedding optimizer defined above:
285
286  ```python
287  table_one = tf.tpu.experimental.embedding.TableConfig(
288      vocabulary_size=...,
289      dim=...,
290      optimizer=tf.tpu.experimental.embedding.Adagrad(0.2))
291  table_two = tf.tpu.experimental.embedding.TableConfig(
292      vocabulary_size=...,
293      dim=...)
294
295  feature_config = (
296      tf.tpu.experimental.embedding.FeatureConfig(
297          table=table_one),
298      tf.tpu.experimental.embedding.FeatureConfig(
299          table=table_two))
300
301  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
302      feature_config=feature_config,
303      batch_size=...
304      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
305  ```
306
307  In the above example, the first feature will be looked up in a table that has
308  a learning rate of 0.2 while the second feature will be looked up in a table
309  that has a learning rate of 0.1.
310
311  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
312  complete description of these parameters and their impacts on the optimizer
313  algorithm.
314  """
315
316  def __init__(
317      self,
318      learning_rate: float = 0.001,
319      initial_accumulator_value: float = 0.1,
320      use_gradient_accumulation: bool = True,
321      clip_weight_min: Optional[float] = None,
322      clip_weight_max: Optional[float] = None,
323      weight_decay_factor: Optional[float] = None,
324      multiply_weight_decay_factor_by_learning_rate: bool = None,
325      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
326      clipvalue: Optional[ClipValueType] = None):
327    """Optimization parameters for Adagrad.
328
329    Args:
330      learning_rate: The learning rate. It should be a floating point value or a
331        callable taking no arguments for a dynamic learning rate.
332      initial_accumulator_value: initial accumulator for Adagrad.
333      use_gradient_accumulation: setting this to `False` makes embedding
334        gradients calculation less accurate but faster.
335      clip_weight_min: the minimum value to clip by; None means -infinity.
336      clip_weight_max: the maximum value to clip by; None means +infinity.
337      weight_decay_factor: amount of weight decay to apply; None means that the
338        weights are not decayed.
339      multiply_weight_decay_factor_by_learning_rate: if true,
340        `weight_decay_factor` is multiplied by the current learning rate.
341      slot_variable_creation_fn: If you wish do directly control the creation of
342        the slot variables, set this to a callable taking three parameters: a
343          table variable, a list of slot names to create for it, and a list of
344          initializers. This function should return a dict with the slot names
345          as keys and the created variables as values with types matching the
346          table variable. When set to None (the default), uses the built-in
347          variable creation.
348      clipvalue: Controls clipping of the gradient. Set to either a single
349        positive scalar value to get clipping or a tuple of scalar values (min,
350        max) to set a separate maximum or minimum. If one of the two entries is
351        None, then there will be no clipping that direction.
352    """
353    super(Adagrad, self).__init__(
354        learning_rate, use_gradient_accumulation, clip_weight_min,
355        clip_weight_max, weight_decay_factor,
356        multiply_weight_decay_factor_by_learning_rate, clipvalue,
357        slot_variable_creation_fn)
358    if initial_accumulator_value <= 0:
359      raise ValueError("Adagrad initial_accumulator_value must be positive")
360    self.initial_accumulator_value = initial_accumulator_value
361
362  def _slot_names(self) -> List[Text]:
363    return ["accumulators"]
364
365  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
366    return [init_ops_v2.Constant(self.initial_accumulator_value)]
367
368  def _set_optimization_parameters(
369      self, parameters: optimization_parameters_pb2.OptimizationParameters):
370    super(Adagrad, self)._set_optimization_parameters(parameters)
371    parameters.adagrad.SetInParent()
372
373  def _load(self) -> Callable[..., ops.Operation]:
374    return tpu_ops.load_tpu_embedding_adagrad_parameters
375
376  def _retrieve(self) -> Callable[..., core.Tensor]:
377    return tpu_ops.retrieve_tpu_embedding_adagrad_parameters
378
379
380@tf_export("tpu.experimental.embedding.Adam")
381class Adam(_Optimizer):
382  """Optimization parameters for Adam with TPU embeddings.
383
384  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
385  argument to set the global optimizer and its parameters:
386
387  NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient
388  update of zero to rows that were not looked up. You can change this behavior
389  by setting `lazy_adam` to `False`.
390
391  ```python
392  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
393      ...
394      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
395  ```
396
397  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
398  optimizer parameter to set a table specific optimizer. This will override the
399  optimizer and parameters for global embedding optimizer defined above:
400
401  ```python
402  table_one = tf.tpu.experimental.embedding.TableConfig(
403      vocabulary_size=...,
404      dim=...,
405      optimizer=tf.tpu.experimental.embedding.Adam(0.2))
406  table_two = tf.tpu.experimental.embedding.TableConfig(
407      vocabulary_size=...,
408      dim=...)
409
410  feature_config = (
411      tf.tpu.experimental.embedding.FeatureConfig(
412          table=table_one),
413      tf.tpu.experimental.embedding.FeatureConfig(
414          table=table_two))
415
416  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
417      feature_config=feature_config,
418      batch_size=...
419      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
420  ```
421
422  In the above example, the first feature will be looked up in a table that has
423  a learning rate of 0.2 while the second feature will be looked up in a table
424  that has a learning rate of 0.1.
425
426  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
427  complete description of these parameters and their impacts on the optimizer
428  algorithm.
429  """
430
431  def __init__(
432      self,
433      learning_rate: Union[float, Callable[[], float]] = 0.001,
434      beta_1: float = 0.9,
435      beta_2: float = 0.999,
436      epsilon: float = 1e-07,
437      lazy_adam: bool = True,
438      sum_inside_sqrt: bool = True,
439      use_gradient_accumulation: bool = True,
440      clip_weight_min: Optional[float] = None,
441      clip_weight_max: Optional[float] = None,
442      weight_decay_factor: Optional[float] = None,
443      multiply_weight_decay_factor_by_learning_rate: bool = None,
444      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
445      clipvalue: Optional[ClipValueType] = None):
446    """Optimization parameters for Adam.
447
448    See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
449    complete description of these parameters and their impacts on the optimizer
450    algorithm.
451
452    Args:
453      learning_rate: The learning rate. It should be a floating point value or a
454        callable taking no arguments for a dynamic learning rate.
455      beta_1: A float value. The exponential decay rate for the 1st moment
456        estimates.
457      beta_2: A float value. The exponential decay rate for the 2nd moment
458        estimates.
459      epsilon: A small constant for numerical stability.
460      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
461      sum_inside_sqrt: When this is true, the Adam update formula is changed
462        from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This
463        option improves the performance of TPU training and is not expected to
464        harm model quality.
465      use_gradient_accumulation: Setting this to `False` makes embedding
466        gradients calculation less accurate but faster.
467      clip_weight_min: the minimum value to clip by; None means -infinity.
468      clip_weight_max: the maximum value to clip by; None means +infinity.
469      weight_decay_factor: amount of weight decay to apply; None means that the
470        weights are not decayed.
471      multiply_weight_decay_factor_by_learning_rate: if true,
472        `weight_decay_factor` is multiplied by the current learning rate.
473      slot_variable_creation_fn: If you wish do directly control the creation of
474        the slot variables, set this to a callable taking three parameters: a
475          table variable, a list of slot names to create for it, and a list of
476          initializers. This function should return a dict with the slot names
477          as keys and the created variables as values with types matching the
478          table variable. When set to None (the default), uses the built-in
479          variable creation.
480      clipvalue: Controls clipping of the gradient. Set to either a single
481        positive scalar value to get clipping or a tiple of scalar values (min,
482        max) to set a separate maximum or minimum. If one of the two entries is
483        None, then there will be no clipping that direction.
484    """
485    super(Adam, self).__init__(
486        learning_rate, use_gradient_accumulation, clip_weight_min,
487        clip_weight_max, weight_decay_factor,
488        multiply_weight_decay_factor_by_learning_rate, clipvalue,
489        slot_variable_creation_fn)
490    if beta_1 < 0. or beta_1 >= 1.:
491      raise ValueError("beta1 must be in the range [0, 1), but received {}."
492                       .format(beta_1))
493    if beta_2 < 0. or beta_2 >= 1.:
494      raise ValueError("beta2 must be in the range [0, 1), but received {}."
495                       .format(beta_2))
496    if epsilon <= 0.:
497      raise ValueError("epsilon must be positive; got {}.".format(epsilon))
498    if not use_gradient_accumulation and not lazy_adam:
499      raise ValueError(
500          "When disabling Lazy Adam, gradient accumulation must be used.")
501
502    self.beta_1 = beta_1
503    self.beta_2 = beta_2
504    self.epsilon = epsilon
505    self.lazy_adam = lazy_adam
506    self.sum_inside_sqrt = sum_inside_sqrt
507
508  def _slot_names(self) -> List[Text]:
509    return ["momenta", "velocities"]
510
511  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
512    return [init_ops_v2.Constant(), init_ops_v2.Constant()]
513
514  def _set_optimization_parameters(
515      self, parameters: optimization_parameters_pb2.OptimizationParameters):
516    super(Adam, self)._set_optimization_parameters(parameters)
517    parameters.adam.beta1 = self.beta_1
518    parameters.adam.beta2 = self.beta_2
519    parameters.adam.epsilon = self.epsilon
520    parameters.adam.use_non_lazy_adam = not self.lazy_adam
521    parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt
522
523  def _load(self) -> Callable[..., ops.Operation]:
524    return tpu_ops.load_tpu_embedding_adam_parameters
525
526  def _retrieve(self) -> Callable[..., core.Tensor]:
527    return tpu_ops.retrieve_tpu_embedding_adam_parameters
528
529
530@tf_export("tpu.experimental.embedding.TableConfig")
531class TableConfig(object):
532  """Configuration data for one embedding table.
533
534  This class holds the configuration data for a single embedding table. It is
535  used as the `table` parameter of a
536  `tf.tpu.experimental.embedding.FeatureConfig`. Multiple
537  `tf.tpu.experimental.embedding.FeatureConfig` objects can use the same
538  `tf.tpu.experimental.embedding.TableConfig` object. In this case a shared
539  table will be created for those feature lookups.
540
541  ```python
542  table_config_one = tf.tpu.experimental.embedding.TableConfig(
543      vocabulary_size=...,
544      dim=...)
545  table_config_two = tf.tpu.experimental.embedding.TableConfig(
546      vocabulary_size=...,
547      dim=...)
548  feature_config = {
549      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
550          table=table_config_one),
551      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
552          table=table_config_one),
553      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
554          table=table_config_two)}
555  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
556      feature_config=feature_config,
557      batch_size=...
558      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
559  ```
560
561  The above configuration has 2 tables, and three features. The first two
562  features will be looked up in the first table and the third feature will be
563  looked up in the second table.
564
565  """
566
567  def __init__(self,
568               vocabulary_size: int,
569               dim: int,
570               initializer: Optional[Callable[[Any], None]],
571               optimizer: Optional[_Optimizer] = None,
572               combiner: Text = "mean",
573               name: Optional[Text] = None):
574    """Embedding table configuration.
575
576    Args:
577      vocabulary_size: Size of the table's vocabulary (number of rows).
578      dim: The embedding dimension (width) of the table.
579      initializer: A callable initializer taking one parameter, the shape of the
580        variable that will be initialized. Will be called once per task, to
581        initialize that task's shard of the embedding table. If not specified,
582        defaults to `truncated_normal_initializer` with mean `0.0` and standard
583        deviation `1/sqrt(dim)`.
584      optimizer: An optional instance of an optimizer parameters class, instance
585        of one of `tf.tpu.experimental.embedding.SGD`,
586        `tf.tpu.experimental.embedding.Adagrad` or
587        `tf.tpu.experimental.embedding.Adam`. It set will override the global
588        optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`.
589      combiner: A string specifying how to reduce if there are multiple entries
590        in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with
591        'mean' the default. 'sqrtn' often achieves good accuracy, in particular
592        with bag-of-words columns. For more information, see
593        `tf.nn.embedding_lookup_sparse`.
594      name: An optional string used to name the table. Useful for debugging.
595
596    Returns:
597      `TableConfig`.
598
599    Raises:
600      ValueError: if `vocabulary_size` is not a positive integer.
601      ValueError: if `dim` is not a positive integer.
602      ValueError: if `initializer` is specified and is not callable.
603      ValueError: if `combiner` is not supported.
604    """
605    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
606      raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size))
607
608    if not isinstance(dim, int) or dim < 1:
609      raise ValueError("Invalid dim {}.".format(dim))
610
611    if (initializer is not None) and (not callable(initializer)):
612      raise ValueError("initializer must be callable if specified.")
613    if initializer is None:
614      initializer = init_ops_v2.TruncatedNormal(mean=0.0,
615                                                stddev=1/math.sqrt(dim))
616
617    if combiner not in ("mean", "sum", "sqrtn"):
618      raise ValueError("Invalid combiner {}".format(combiner))
619
620    self.vocabulary_size = vocabulary_size
621    self.dim = dim
622    self.initializer = initializer
623    self.optimizer = optimizer
624    self.combiner = combiner
625    self.name = name
626
627  def __repr__(self):
628    # If using the default initializer, just print "None" for clarity.
629    initializer = self.initializer
630
631    if isinstance(initializer, init_ops_v2.TruncatedNormal):
632      # PY2 type checking can't infer type of initializer even after if.
633      initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
634      if (initializer.mean == 0.0
635          and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))):  # pytype: disable=module-attr (math.isclose not in PY2)
636        initializer = None
637
638    return (
639        "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
640        "initializer={initializer!r}, optimizer={optimizer!r}, "
641        "combiner={combiner!r}, name={name!r})".format(
642            vocabulary_size=self.vocabulary_size,
643            dim=self.dim,
644            initializer=initializer,
645            optimizer=self.optimizer,
646            combiner=self.combiner,
647            name=self.name,)
648    )
649
650
651@tf_export("tpu.experimental.embedding.FeatureConfig")
652class FeatureConfig(object):
653  """Configuration data for one embedding feature.
654
655  This class holds the configuration data for a single embedding feature. The
656  main use is to assign features to `tf.tpu.experimental.embedding.TableConfig`s
657  via the table parameter:
658
659  ```python
660  table_config_one = tf.tpu.experimental.embedding.TableConfig(
661      vocabulary_size=...,
662      dim=...)
663  table_config_two = tf.tpu.experimental.embedding.TableConfig(
664      vocabulary_size=...,
665      dim=...)
666  feature_config = {
667      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
668          table=table_config_one),
669      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
670          table=table_config_one),
671      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
672          table=table_config_two)}
673  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
674      feature_config=feature_config,
675      batch_size=...
676      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
677  ```
678
679  The above configuration has 2 tables, and three features. The first two
680  features will be looked up in the first table and the third feature will be
681  looked up in the second table.
682
683  When feeding features into `embedding.enqueue` they can be `tf.Tensor`s,
684  `tf.SparseTensor`s or `tf.RaggedTensor`s. When the argument
685  `max_sequence_length` is 0, the default, you should expect a output of
686  `embedding.dequeue` for this feature of shape `(batch_size, dim)`. If
687  `max_sequence_length` is greater than 0, the feature is embedded as a sequence
688  and padded up to the given length. The shape of the output for this feature
689  will be `(batch_size, max_sequence_length, dim)`.
690  """
691
692  def __init__(self,
693               table: TableConfig,
694               max_sequence_length: int = 0,
695               name: Optional[Text] = None):
696    """Feature configuration.
697
698    Args:
699      table: An instance of `tf.tpu.experimental.embedding.TableConfig`,
700        describing the table in which this feature should be looked up.
701      max_sequence_length: If positive, the feature is a sequence feature with
702        the corresponding maximum sequence length. If the sequence is longer
703        than this, it will be truncated. If 0, the feature is not a sequence
704        feature.
705      name: An optional name for the feature, useful for debugging.
706
707    Returns:
708      `FeatureConfig`.
709
710    Raises:
711      ValueError: if `table` is not an instance of
712        `tf.tpu.experimental.embedding.TableConfig`.
713      ValueError: if `max_sequence_length` not an integer or is negative.
714    """
715    if not isinstance(table, TableConfig):
716      raise ValueError("table is type {}, expected "
717                       "`tf.tpu.experimental.embedding.TableConfig`".format(
718                           type(table)))
719
720    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
721      raise ValueError("Invalid max_sequence_length {}.".format(
722          max_sequence_length))
723
724    self.table = table
725    self.max_sequence_length = max_sequence_length
726    self.name = name
727
728  def __repr__(self):
729    return (
730        "FeatureConfig(table={table!r}, "
731        "max_sequence_length={max_sequence_length!r}, name={name!r})"
732        .format(
733            table=self.table,
734            max_sequence_length=self.max_sequence_length,
735            name=self.name)
736    )
737
738
739def log_tpu_embedding_configuration(
740    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
741  """Logs a TPUEmbeddingConfiguration proto across multiple statements.
742
743  Args:
744    config: TPUEmbeddingConfiguration proto to log.  Necessary because
745      logging.info has a maximum length to each log statement, which
746      particularly large configs can exceed.
747  """
748  logging.info("Beginning log of TPUEmbeddingConfiguration.")
749  for line in str(config).splitlines():
750    logging.info(line)
751  logging.info("Done with log of TPUEmbeddingConfiguration.")
752