1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15"""TPU Feature Column Library."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import math
21
22from tensorflow.python.feature_column import feature_column as fc
23from tensorflow.python.feature_column import feature_column_lib as fc_lib
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import init_ops
26from tensorflow.python.ops import variable_scope
27from tensorflow.python.tpu import tpu
28from tensorflow.python.tpu import tpu_function
29# pylint: disable=protected-access
30
31
32_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
33_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
34                                  fc._VocabularyFileCategoricalColumn,
35                                  fc._VocabularyListCategoricalColumn,
36                                  fc._WeightedCategoricalColumn,
37                                  fc_lib.IdentityCategoricalColumn,
38                                  fc_lib.VocabularyFileCategoricalColumn,
39                                  fc_lib.VocabularyListCategoricalColumn,
40                                  fc_lib.WeightedCategoricalColumn)
41
42
43def embedding_column(categorical_column,
44                     dimension,
45                     combiner='mean',
46                     initializer=None):
47  """TPU embedding_column for `tf.feature_column.embedding_column`.
48
49  Note that the interface for TPU embedding_column is different from the non-TPU
50  version. The following args available for the non-TPU version are NOT
51  supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable.
52
53  Args:
54    categorical_column: A categorical_column returned from
55        categorical_column_with_identity,  weighted_categorical_column,
56        categorical_column_with_vocabulary_list or
57        categorical_column_with_vocabulary_file.
58    dimension: An integer specifying dimension of the embedding, must be > 0.
59    combiner: A string specifying how to reduce if there are multiple entries
60      in a single row. For more information, see
61      `tf.feature_column.embedding_column`.
62    initializer: A variable initializer function to be used in embedding
63      variable initialization. If not specified, defaults to
64      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
65      `1/sqrt(dimension)`.
66
67  Returns:
68    A  _TPUEmbeddingColumn.
69
70  Raises:
71    ValueError: if `dimension` not > 0.
72    ValueError: if `initializer` is specified but not callable.
73  """
74  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
75    raise TypeError(
76        'categorical_column for tpu '
77        ' embedding_column must be type %s, got %s.' % (' or '.join([
78            cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
79        ]), type(categorical_column)))
80  if (dimension is None) or (dimension < 1):
81    raise ValueError('Invalid dimension {}.'.format(dimension))
82
83  if (initializer is not None) and (not callable(initializer)):
84    raise ValueError('initializer must be callable if specified. '
85                     'Embedding of column_name: {}'.format(
86                         categorical_column.name))
87  if initializer is None:
88    initializer = init_ops.truncated_normal_initializer(
89        mean=0.0, stddev=1 / math.sqrt(dimension))
90
91  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access
92
93  def _creator(weight_collections, scope):
94    embedding_column_layer = fc._EmbeddingColumnLayer(
95        embedding_shape=embedding_shape,
96        initializer=initializer,
97        weight_collections=weight_collections,
98        trainable=True,
99        name='embedding_column_layer')
100    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable
101
102  column = _TPUEmbeddingColumn(
103      categorical_column=categorical_column,
104      dimension=dimension,
105      combiner=combiner,
106      layer_creator=_creator,
107      ckpt_to_load_from=None,
108      tensor_name_in_ckpt=None,
109      max_norm=None,
110      trainable=True)
111  # For Embedding column, the initializer is hidden inside the creator Fn, which
112  # is not accessiable later. So, we attach it to a speicial field. Also note
113  # that non-TPU Embedding column and non-TPU shared Embedding column handle the
114  # initializer differently. See shared_embedding_columns for details.
115  column._tpu_initializer = initializer
116  return column
117
118
119def shared_embedding_columns(categorical_columns,
120                             dimension,
121                             combiner='mean',
122                             initializer=None,
123                             shared_embedding_collection_name=None):
124  """List of dense columns that convert from sparse, categorical input."""
125  for categorical_column in categorical_columns:
126    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
127      raise TypeError(
128          'categorical_column for tpu '
129          ' shared_embedding_columns must be type %s, got %s.' % (' or '.join([
130              cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
131          ]), type(categorical_column)))
132  columns = fc_lib.shared_embedding_columns(
133      categorical_columns,
134      dimension,
135      combiner=combiner,
136      initializer=initializer,
137      shared_embedding_collection_name=shared_embedding_collection_name,
138      ckpt_to_load_from=None,
139      tensor_name_in_ckpt=None,
140      max_norm=None,
141      trainable=True)
142
143  # Use the initializer and shared_embedding_collection_name to create TPU
144  # version
145  initializer = columns[0].initializer
146  shared_embedding_collection_name = columns[0].shared_embedding_collection_name
147  tpu_columns = []
148
149  # Create the state (_SharedEmbeddingColumnLayer) here.
150  for categorical_column in categorical_columns:
151    column = _TPUSharedEmbeddingColumn(
152        categorical_column=categorical_column,
153        dimension=dimension,
154        combiner=combiner,
155        initializer=initializer,
156        shared_embedding_collection_name=shared_embedding_collection_name,
157        ckpt_to_load_from=None,
158        tensor_name_in_ckpt=None,
159        max_norm=None,
160        trainable=True)
161    tpu_columns.append(column)
162
163  return tpu_columns
164
165
166class _TPUBaseEmbeddingColumn(object):
167  """Base class for TPU Embedding Column."""
168
169  def __init__(self, categorical_column):
170    self._tpu_categorical_column = categorical_column
171
172  def get_combiner(self):
173    """Returns the embedding combiner."""
174    raise NotImplementedError('not implemented')
175
176  def get_embedding_table_size(self):
177    """Returns the embedding table size, tuple of vocab size and dimension."""
178    raise NotImplementedError('not implemented')
179
180  def get_feature_key_name(self):
181    """Returns the feature key name in the features dict."""
182    raise NotImplementedError('not impl')
183
184  def get_weight_key_name(self):
185    """Return the key name for weights."""
186    raise NotImplementedError('not impl')
187
188  def get_embedding_var_name(self):
189    """Returns the embedding variable name.
190
191    Feature key name and embedding variable name are usually one-to-one mapping.
192    But for shared embedding columns, it is many-to-one mapping.
193    """
194    raise NotImplementedError('not impl')
195
196  def get_initializer(self):
197    """Returns the initializer."""
198    raise NotImplementedError('not impl')
199
200  def is_categorical_column_weighted(self):
201    """Check if the categorical column of the embedding column is weighted."""
202    raise NotImplementedError('not impl')
203
204
205class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
206  """Core Embedding Column."""
207
208  def __new__(cls,
209              categorical_column,
210              dimension,
211              combiner='mean',
212              layer_creator=None,
213              ckpt_to_load_from=None,
214              tensor_name_in_ckpt=None,
215              max_norm=None,
216              trainable=True):
217    # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
218    # are not supported on TPU. They are solely for matching the signature of
219    # __new__ of parent class fc._EmbeddingColumn.
220    return fc._EmbeddingColumn.__new__(
221        cls,
222        categorical_column,
223        dimension,
224        combiner=combiner,
225        layer_creator=layer_creator,
226        ckpt_to_load_from=ckpt_to_load_from,
227        tensor_name_in_ckpt=tensor_name_in_ckpt,
228        max_norm=max_norm,
229        trainable=trainable)
230
231  def __init__(self,
232               categorical_column,
233               dimension,
234               combiner='mean',
235               layer_creator=None,
236               ckpt_to_load_from=None,
237               tensor_name_in_ckpt=None,
238               max_norm=None,
239               trainable=True):
240    _TPUBaseEmbeddingColumn.__init__(self, categorical_column)
241    self._key = None
242
243  def get_combiner(self):
244    return self.combiner
245
246  def get_embedding_table_size(self):
247    """Returns num_ids and width."""
248    return (self.categorical_column._num_buckets, self.dimension)
249
250  def get_feature_key_name(self):
251    """get_feature_key_name."""
252    if self.is_categorical_column_weighted():
253      return self.categorical_column.categorical_column.name
254    return self.categorical_column.name
255
256  def get_weight_key_name(self):
257    """get_weight_key_name."""
258    if self.is_categorical_column_weighted():
259      return self.categorical_column.weight_feature_key
260    return None
261
262  def get_embedding_var_name(self):
263    """get_embedding_var_name."""
264    return self.categorical_column.name
265
266  def get_initializer(self):
267    return self._tpu_initializer
268
269  def is_categorical_column_weighted(self):
270    """Check if the categorical column of the embedding column is weighted."""
271    if isinstance(
272        self.categorical_column,
273        (
274            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
275            fc_lib.WeightedCategoricalColumn)):
276      return True
277    return False
278
279  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
280    if tpu.under_tpu_inference_context():
281      def host_computation():
282        return fc._EmbeddingColumn._get_dense_tensor(
283            self, inputs, weight_collections, trainable)
284      return tpu.outside_compilation(host_computation)
285
286    if _is_running_on_cpu():
287      return fc._EmbeddingColumn._get_dense_tensor(
288          self, inputs, weight_collections, trainable)
289
290    # TPU mode
291    # Get the embeddings from the LazyBuilder.
292    tensor = inputs.get(self.get_feature_key_name())
293
294    # Add to collection for _create_tpu_embedding_variables_and_ops
295    _record_variable_scope_and_name(self.get_embedding_var_name(),
296                                    'embedding_weights')
297
298    return tensor
299
300
301class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
302                                fc._SharedEmbeddingColumn):
303  """Core Shared Embedding Column."""
304
305  def __new__(cls,
306              categorical_column,
307              dimension,
308              combiner='mean',
309              initializer=None,
310              shared_embedding_collection_name=None,
311              ckpt_to_load_from=None,
312              tensor_name_in_ckpt=None,
313              max_norm=None,
314              trainable=True):
315    return fc._SharedEmbeddingColumn.__new__(
316        cls,
317        categorical_column,
318        dimension,
319        combiner=combiner,
320        initializer=initializer,
321        shared_embedding_collection_name=shared_embedding_collection_name,
322        ckpt_to_load_from=ckpt_to_load_from,
323        tensor_name_in_ckpt=tensor_name_in_ckpt,
324        max_norm=max_norm,
325        trainable=trainable)
326
327  def __init__(self,
328               categorical_column,
329               dimension,
330               combiner='mean',
331               initializer=None,
332               shared_embedding_collection_name=None,
333               ckpt_to_load_from=None,
334               tensor_name_in_ckpt=None,
335               max_norm=None,
336               trainable=True):
337
338    _TPUBaseEmbeddingColumn.__init__(self, categorical_column)
339    self._key = None
340
341  def get_combiner(self):
342    return self.combiner
343
344  def get_embedding_table_size(self):
345    """Returns num_ids and width."""
346    return (self.categorical_column._num_buckets, self.dimension)
347
348  def get_feature_key_name(self):
349    """get_feature_key_name."""
350    if self.is_categorical_column_weighted():
351      return self.categorical_column.categorical_column.name
352    return self.categorical_column.name
353
354  def get_weight_key_name(self):
355    """get_weight_key_name."""
356    if self.is_categorical_column_weighted():
357      return self.categorical_column.weight_feature_key
358    return None
359
360  def get_embedding_var_name(self):
361    """get_embedding_var_name."""
362    return self.shared_embedding_collection_name
363
364  def get_initializer(self):
365    return self.initializer
366
367  def is_categorical_column_weighted(self):
368    """Check if the categorical column of the embedding column is weighted."""
369    if isinstance(
370        self.categorical_column,
371        (
372            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
373            fc_lib.WeightedCategoricalColumn)):
374      return True
375    return False
376
377  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
378    if tpu.under_tpu_inference_context():
379      def host_computation():
380        return fc._SharedEmbeddingColumn._get_dense_tensor(
381            self, inputs, weight_collections, trainable)
382      return tpu.outside_compilation(host_computation)
383
384    if _is_running_on_cpu():
385      return fc._SharedEmbeddingColumn._get_dense_tensor(
386          self, inputs, weight_collections, trainable)
387
388    # TPU mode
389    # Get the embeddings from the LazyBuilder.
390    tensor = inputs.get(self.get_feature_key_name())
391
392    # Add to collection for _create_tpu_embedding_variables_and_ops
393    _record_variable_scope_and_name(
394        self.get_embedding_var_name(),
395        'embedding_weights',
396        is_shared_embedding=True)
397    return tensor
398
399
400def _record_variable_scope_and_name(embedding_var_name,
401                                    embedding_var_name_in_fc,
402                                    is_shared_embedding=False):
403  """Add embedding variable name and scope to collection."""
404  g = ops.get_default_graph()
405  collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
406  if not collection:
407    collection.append({})
408
409  var_def_dict = collection[0]
410
411  captured_scope = variable_scope.get_variable_scope()
412  captured_scope_name = captured_scope.name
413
414  if embedding_var_name in var_def_dict:
415    if (var_def_dict[embedding_var_name][0] != captured_scope_name
416        and not is_shared_embedding):
417      raise ValueError(
418          'For embedding var name {}, the variable scope name is different, '
419          'got {}; expected {}'.format(embedding_var_name,
420                                       captured_scope_name,
421                                       var_def_dict[embedding_var_name][0]))
422    if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
423      raise ValueError(
424          'For embedding var name {}, the embedding name is different, '
425          'got {}; expected {}'.format(embedding_var_name,
426                                       embedding_var_name_in_fc,
427                                       var_def_dict[embedding_var_name][1]))
428  else:
429    var_def_dict[embedding_var_name] = (captured_scope_name,
430                                        embedding_var_name_in_fc)
431
432
433def _is_running_on_cpu():
434  """Returns True if the current context is CPU model."""
435  return tpu_function.get_tpu_context().number_of_shards is None
436