1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15"""Tooling for support TPU embedding in TPUEstimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.estimator import model_fn as model_fn_lib
24from tensorflow.python.feature_column import feature_column as core_fc
25from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
26from tensorflow.python.tpu import feature_column as tpu_fc
27from tensorflow.python.tpu import tpu_embedding
28from tensorflow.python.tpu.tpu_embedding import AdagradParameters
29from tensorflow.python.tpu.tpu_embedding import AdamParameters
30from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
31
32# pylint: disable=protected-access
33_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
34                                 tpu_fc._TPUSharedEmbeddingColumn)
35_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
36                             core_fc_lib.EmbeddingColumn,
37                             core_fc._SharedEmbeddingColumn)
38_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
39_SUPPORTED_OPTIMIZERS = (AdagradParameters, AdamParameters,
40                         StochasticGradientDescentParameters)
41
42# pylint: enable=protected-access
43
44_TABLE_NAME_PREFIX = 'tbl_'
45_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
46
47
48def _get_table_name_from_embedding_var_name(embedding_var_name):
49  return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
50
51
52def _get_embedding_var_name_from_table_name(table_name):
53  return table_name[_LEN_TABLE_NAME_PREFIX:]
54
55
56def _get_embedding_variable_name(scope_name, var_name):
57  return '{}/{}'.format(scope_name, var_name)
58
59
60def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
61  """Return embedding variable names which are consistent with CPU runs."""
62  if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
63    return tpu_embedding.AdagradSlotVariableName(
64        '{}/{}/Adagrad'.format(scope_name, var_name)
65    )
66  elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
67    return tpu_embedding.AdamSlotVariableNames(
68        '{}/{}/Adam/m'.format(scope_name, var_name),
69        '{}/{}/Adam/v'.format(scope_name, var_name)
70    )
71  elif isinstance(optimization_parameters,
72                  tpu_embedding.StochasticGradientDescentParameters):
73    return None
74  else:
75    raise ValueError('Support to infer full variable name '
76                     'for optimization_parameter {} has not been added.'
77                     .format(optimization_parameters))
78
79
80def get_full_variable_names(
81    graph, table_to_config_dict, optimization_parameters=None):
82  """Return embedding variable names and slot variables which are consistent with CPU runs."""
83  collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE)  # pylint: disable=protected-access
84  if not collection:
85    raise RuntimeError(
86        'Embedding feature column did not capture any thing. Make sure the '
87        'feature columns passed to TPUEstimator constructor is properly '
88        'used in model_fn.')
89
90  embedding_variable_name_by_table = {}
91  slot_variable_names_by_table = {}
92  for table_name in table_to_config_dict:
93    embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
94    (scope_name, var_name) = collection[0][embedding_var_name]
95    embedding_variable_name_by_table[table_name] = (
96        _get_embedding_variable_name(scope_name, var_name))
97    if optimization_parameters:
98      slot_variable_names_by_table[table_name] = _get_slot_variable_names(
99          scope_name, var_name, optimization_parameters)
100
101  graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE)  # pylint: disable=protected-access
102  return embedding_variable_name_by_table, slot_variable_names_by_table
103
104
105def get_tpu_embedding_config_from_feature_columns(feature_columns):
106  """Create configs for TPUEmbedding from a list of feature columns.
107
108  This function will place one embedding tensor per table and the return is
109  intended to be used as input to TPUEmbedding.
110
111  Args:
112    feature_columns: a list of supported feature columns.
113
114  Returns:
115    A pair of dicts, the first maps tables to their config, the second maps
116    features to tables.
117  """
118
119  allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access
120
121  for column in feature_columns:
122    if not isinstance(column, allowed):
123      raise TypeError(
124          'Unsupported feature column {}. Supported types are {}.'.format(
125              type(column), allowed))
126
127  table_to_config = {}
128  feature_to_table = {}
129  for column in feature_columns:
130    feature_name = column.get_feature_key_name()
131    table_name = _get_table_name_from_embedding_var_name(
132        column.get_embedding_var_name())
133    if feature_name in feature_to_table:
134      raise ValueError(
135          'Feature column {} is used with multiple embeddings and this is '
136          'not supported.'.format(feature_name))
137    feature_to_table[feature_name] = table_name
138    vocabulary_size, dimension = column.get_embedding_table_size()
139    table_to_config[table_name] = tpu_embedding.TableConfig(
140        vocabulary_size=vocabulary_size,
141        dimension=dimension,
142        initializer=column.get_initializer(),
143        combiner=column.get_combiner())
144
145  return table_to_config, feature_to_table
146
147
148class EmbeddingConfigSpec(
149    collections.namedtuple('EmbeddingConfigSpec', [
150        'feature_columns', 'optimization_parameters', 'clipping_limit',
151    ])):
152  """Class to keep track of embedding config specification."""
153
154  def __new__(cls,
155              feature_columns,
156              optimization_parameters,
157              clipping_limit=None):
158    """Creates an EmbeddingConfigSpec instance.
159
160    Args:
161      feature_columns: All `FeatureColumn`s used by model.
162      optimization_parameters: An instance of `AdagradParameters`,
163        `AdamParameters` or `StochasticGradientDescentParameters`. This
164        optimizer will be applied to all embedding variables specified by
165        `feature_columns`.
166      clipping_limit: (Optional) Clipping limit (absolute value).
167
168    Returns:
169      An EmbeddingConfigSpec instance.
170
171    Raises:
172      ValueError: If the feature_columns are not specified.
173      TypeError: If the feature columns are not of ths correct type (one of
174        _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR
175        _EMBEDDING_COLUMN_CLASSES).
176      ValueError: If `optimization_parameters` is not one of the required types.
177    """
178    if not feature_columns:
179      raise ValueError('`feature_columns` cannot be `None` or empty.')
180
181    # It is unknown at this moment, whether the TPUEstimator is running in CPU
182    # or TPU mode. So allow non-TPU embedding columns also.
183    supported_classes = tuple(
184        list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
185        list(_EMBEDDING_COLUMN_CLASSES))
186
187    for column in feature_columns:
188      if not isinstance(column, supported_classes):
189        raise TypeError(
190            'All feature columns must be supported types in {}. Got {}'.format(
191                supported_classes, type(column)))
192
193    if not isinstance(optimization_parameters, _SUPPORTED_OPTIMIZERS):
194      raise ValueError('optimization_parameters must be an instance of type '
195                       '{}. Got {}.'.format(_SUPPORTED_OPTIMIZERS,
196                                            type(optimization_parameters)))
197
198    return super(EmbeddingConfigSpec, cls).__new__(
199        cls,
200        feature_columns=feature_columns,
201        optimization_parameters=optimization_parameters,
202        clipping_limit=clipping_limit)
203
204
205class EmbeddingConfig(object):
206  """This is the internal immutable object for embedding config.
207
208  `_EmbeddingConfig` is responsible to _translate_ user provided
209  `EmbeddingConfigSpec` to internal data structures, mostly constructor
210  arguments of `TPUEmbedding`.
211  """
212
213  def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
214               num_hosts, num_cores, run_config):
215    self._embedding_config_spec = embedding_config_spec
216    self._train_batch_size = train_batch_size
217    self._eval_batch_size = eval_batch_size
218    self._num_hosts = num_hosts
219    self._num_cores = num_cores
220    self._run_config = run_config
221
222    self._table_to_config_dict, self._feature_to_table_dict = (
223        get_tpu_embedding_config_from_feature_columns(
224            embedding_config_spec.feature_columns))
225    self._mode_to_tpu_embedding_dict = {}
226    self.dummy_table_variables = None
227
228  def has_embedding_tables(self):
229    return bool(self._table_to_config_dict)
230
231  def _create_tpu_embedding(self, mode):
232    """Create tpu_embedding.TPUEmbedding based on mode."""
233    if mode == model_fn_lib.ModeKeys.TRAIN:
234      batch_size = self._train_batch_size
235    else:
236      batch_size = self._eval_batch_size
237
238    if mode == model_fn_lib.ModeKeys.TRAIN:
239      tpu_embedding_mode = tpu_embedding.TRAINING
240      optimization_parameters = (
241          self._embedding_config_spec.optimization_parameters)
242    elif (mode == model_fn_lib.ModeKeys.EVAL or
243          mode == model_fn_lib.ModeKeys.PREDICT):
244      tpu_embedding_mode = tpu_embedding.INFERENCE
245      optimization_parameters = None
246    else:
247      raise ValueError('Mode {} is not supported.'.format(mode))
248
249    if self._run_config.cluster:
250      master = self._run_config.cluster.master()
251      cluster_spec = self._run_config.cluster.cluster_spec()
252      cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
253    else:
254      master = (
255          self._run_config.evaluation_master
256          if mode == model_fn_lib.ModeKeys.EVAL else self._run_config.master)
257      cluster_def = None
258    tpu_embedding_ = tpu_embedding.TPUEmbedding(
259        self._table_to_config_dict,
260        self._feature_to_table_dict,
261        batch_size,
262        tpu_embedding_mode,
263        master,
264        optimization_parameters,
265        cluster_def,
266    )
267    return tpu_embedding_
268
269  def get_tpu_embedding(self, mode):
270    if mode not in self._mode_to_tpu_embedding_dict:
271      self._mode_to_tpu_embedding_dict[mode] = (
272          self._create_tpu_embedding(mode))
273    return self._mode_to_tpu_embedding_dict[mode]
274
275
276def split_inputs(ctx, features, labels):
277  """Splits the dense and sparse tensors inside the features and labels."""
278  sparse_features = collections.OrderedDict()
279  if ctx.embedding_config:
280    tpu_embedding_ = ctx.embedding_config.tpu_embedding
281    for feature_key in tpu_embedding_.feature_to_table_dict:
282      sparse_features[feature_key] = features.pop(feature_key)
283
284  return features, labels, sparse_features
285