# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This API defines FeatureColumn abstraction.
16
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
20
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
23
241. Feature type:
25
26  * Continuous features can be represented by `numeric_column`.
27  * Categorical features can be represented by any `categorical_column_with_*`
28  column:
29    - `categorical_column_with_vocabulary_list`
30    - `categorical_column_with_vocabulary_file`
31    - `categorical_column_with_hash_bucket`
32    - `categorical_column_with_identity`
33    - `weighted_categorical_column`
34
352. Model type:
36
37  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
38
39    Continuous features can be directly fed into deep neural network models.
40
41      age_column = numeric_column("age")
42
43    To feed sparse features into DNN models, wrap the column with
44    `embedding_column` or `indicator_column`. `indicator_column` is recommended
45    for features with only a few possible values. For features with many
46    possible values, to reduce the size of your model, `embedding_column` is
47    recommended.
48
49      embedded_dept_column = embedding_column(
50          categorical_column_with_vocabulary_list(
51              "department", ["math", "philosophy", ...]), dimension=10)
52
53  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
54
55    Sparse features can be fed directly into linear models. They behave like an
56    indicator column but with an efficient implementation.
57
58      dept_column = categorical_column_with_vocabulary_list("department",
59          ["math", "philosophy", "english"])
60
61    It is recommended that continuous features be bucketized before being
62    fed into linear models.
63
64      bucketized_age_column = bucketized_column(
65          source_column=age_column,
66          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
67
68    Sparse features can be crossed (also known as conjuncted or combined) in
69    order to form non-linearities, and then fed into linear models.
70
71      cross_dept_age_column = crossed_column(
72          columns=["department", bucketized_age_column],
73          hash_bucket_size=1000)
74
75Example of building canned `Estimator`s using FeatureColumns:
76
77  ```python
78  # Define features and transformations
79  deep_feature_columns = [age_column, embedded_dept_column]
80  wide_feature_columns = [dept_column, bucketized_age_column,
81      cross_dept_age_column]
82
83  # Build deep model
84  estimator = DNNClassifier(
85      feature_columns=deep_feature_columns,
86      hidden_units=[500, 250, 50])
87  estimator.train(...)
88
89  # Or build a wide model
90  estimator = LinearClassifier(
91      feature_columns=wide_feature_columns)
92  estimator.train(...)
93
94  # Or build a wide and deep model!
95  estimator = DNNLinearCombinedClassifier(
96      linear_feature_columns=wide_feature_columns,
97      dnn_feature_columns=deep_feature_columns,
98      dnn_hidden_units=[500, 250, 50])
99  estimator.train(...)
100  ```
101
102
103FeatureColumns can also be transformed into a generic input layer for
104custom models using `input_layer`.
105
106Example of building model using FeatureColumns, this can be used in a
107`model_fn` which is given to the {tf.estimator.Estimator}:
108
109  ```python
110  # Building model via layers
111
112  deep_feature_columns = [age_column, embedded_dept_column]
113  columns_to_tensor = parse_feature_columns_from_examples(
114      serialized=my_data,
115      feature_columns=deep_feature_columns)
116  first_layer = input_layer(
117      features=columns_to_tensor,
118      feature_columns=deep_feature_columns)
119  second_layer = fully_connected(first_layer, ...)
120  ```
121
122NOTE: Functions prefixed with "_" indicate experimental or private parts of
123the API subject to change, and should not be relied upon!
124
125NOTE: The new feature columns are being developed in feature_column_v2.py and
126are a somewhat duplicate of the code here. Please make sure to update logic
127in both places.
128"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import math

import numpy as np
import six

from tensorflow.python.eager import context
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


def _internal_input_layer(features,
                          feature_columns,
                          weight_collections=None,
                          trainable=True,
                          cols_to_vars=None,
                          scope=None,
                          cols_to_output_tensors=None,
                          from_template=False):
  """See input_layer. `scope` is a name or variable scope to use."""

  feature_columns = _normalize_feature_columns(feature_columns)
  for column in feature_columns:
    if not isinstance(column, _DenseColumn):
      raise ValueError(
          'Items of feature_columns must be a _DenseColumn. '
          'You can wrap a categorical column with an '
          'embedding_column or indicator_column. Given: {}'.format(column))
  weight_collections = list(weight_collections or [])
  if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
  if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)

  def _get_logits():  # pylint: disable=missing-docstring
    builder = _LazyBuilder(features)
    output_tensors = []
    ordered_columns = []
    for column in sorted(feature_columns, key=lambda x: x.name):
      ordered_columns.append(column)
      with variable_scope.variable_scope(
          None, default_name=column._var_scope_name):  # pylint: disable=protected-access
        tensor = column._get_dense_tensor(  # pylint: disable=protected-access
            builder,
            weight_collections=weight_collections,
            trainable=trainable)
        num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
        batch_size = array_ops.shape(tensor)[0]
        output_tensor = array_ops.reshape(
            tensor, shape=(batch_size, num_elements))
        output_tensors.append(output_tensor)
        if cols_to_vars is not None:
          # Retrieve any variables created (some _DenseColumns don't create
          # variables, in which case an empty list is returned).
          cols_to_vars[column] = ops.get_collection(
              ops.GraphKeys.GLOBAL_VARIABLES,
              scope=variable_scope.get_variable_scope().name)
        if cols_to_output_tensors is not None:
          cols_to_output_tensors[column] = output_tensor
    _verify_static_batch_size_equality(output_tensors, ordered_columns)
    return array_ops.concat(output_tensors, 1)

  # If we're constructing from `make_template`, it by default adds a variable
  # scope with the name of the layer. In that case, we don't want to add
  # another `variable_scope`, as that would break checkpoints.
  if from_template:
    return _get_logits()
  else:
    with variable_scope.variable_scope(
        scope, default_name='input_layer', values=features.values()):
      return _get_logits()


@tf_export(v1=['feature_column.input_layer'])
def input_layer(features,
                feature_columns,
                weight_collections=None,
                trainable=True,
                cols_to_vars=None,
                cols_to_output_tensors=None):
  """Returns a dense `Tensor` as input layer based on given `feature_columns`.

  Generally a single example in training data is described with FeatureColumns.
  At the first layer of the model, this column-oriented data should be converted
  to a single `Tensor`.

  Example:

  ```python
  price = numeric_column('price')
  keywords_embedded = embedding_column(
      categorical_column_with_hash_bucket("keywords", 10000), dimension=16)
  columns = [price, keywords_embedded, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  for units in [128, 64, 32]:
    dense_tensor = tf.compat.v1.layers.dense(dense_tensor, units, tf.nn.relu)
  prediction = tf.compat.v1.layers.dense(dense_tensor, 1)
  ```

  Args:
    features: A mapping from key to tensors. `_FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depending on
      corresponding `_FeatureColumn`.
    feature_columns: An iterable containing the FeatureColumns to use as inputs
      to your model. All items should be instances of classes derived from
      `_DenseColumn` such as `numeric_column`, `embedding_column`,
      `bucketized_column`, `indicator_column`. If you have categorical features,
      you can wrap them with an `embedding_column` or `indicator_column`.
    weight_collections: A list of collection names to which the Variable will be
      added. Note that variables will also be added to collections
      `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    cols_to_vars: If not `None`, must be a dictionary that will be filled with a
      mapping from `_FeatureColumn` to list of `Variable`s.  For example, after
      the call, we might have cols_to_vars =
      {_EmbeddingColumn(
        categorical_column=_HashedCategoricalColumn(
          key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
        dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10)>,
                        <tf.Variable 'some_variable:1' shape=(5, 10)>]}
      If a column creates no variables, its value will be an empty list.
    cols_to_output_tensors: If not `None`, must be a dictionary that will be
      filled with a mapping from `_FeatureColumn` to the associated
      output `Tensor`s.

  Returns:
    A `Tensor` which represents the input layer of a model. Its shape
    is (batch_size, first_layer_dimension) and its dtype is `float32`.
    first_layer_dimension is determined based on given `feature_columns`.

  Raises:
    ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
  """
  return _internal_input_layer(
      features,
      feature_columns,
      weight_collections=weight_collections,
      trainable=trainable,
      cols_to_vars=cols_to_vars,
      cols_to_output_tensors=cols_to_output_tensors)


# TODO(akshayka): InputLayer should be a subclass of Layer, and it
# should implement the logic in input_layer using Layer's build-and-call
# paradigm; input_layer should create an instance of InputLayer and
# return the result of invoking its apply method, just as functional layers do.
class InputLayer(object):
  """An object-oriented version of `input_layer` that reuses variables."""

  def __init__(self,
               feature_columns,
               weight_collections=None,
               trainable=True,
               cols_to_vars=None,
               name='feature_column_input_layer',
               create_scope_now=True):
    """See `input_layer`."""

    self._feature_columns = feature_columns
    self._weight_collections = weight_collections
    self._trainable = trainable
    self._cols_to_vars = cols_to_vars
    self._name = name
    self._input_layer_template = template.make_template(
        self._name, _internal_input_layer, create_scope_now_=create_scope_now)
    self._scope = self._input_layer_template.variable_scope

  def __call__(self, features):
    return self._input_layer_template(
        features=features,
        feature_columns=self._feature_columns,
        weight_collections=self._weight_collections,
        trainable=self._trainable,
        cols_to_vars=None,
        from_template=True)

  @property
  def name(self):
    return self._name

  @property
  def non_trainable_variables(self):
    return self._input_layer_template.non_trainable_variables

  @property
  def non_trainable_weights(self):
    return self._input_layer_template.non_trainable_weights

  @property
  def trainable_variables(self):
    return self._input_layer_template.trainable_variables

  @property
  def trainable_weights(self):
    return self._input_layer_template.trainable_weights

  @property
  def variables(self):
    return self._input_layer_template.variables

  @property
  def weights(self):
    return self._input_layer_template.weights


@tf_export(v1=['feature_column.linear_model'])
def linear_model(features,
                 feature_columns,
                 units=1,
                 sparse_combiner='sum',
                 weight_collections=None,
                 trainable=True,
                 cols_to_vars=None):
  """Returns a linear prediction `Tensor` based on given `feature_columns`.

  This function generates a weighted sum based on output dimension `units`.
  Weighted sum refers to logits in classification problems. It refers to the
  prediction itself for linear regression problems.

  Note on supported columns: `linear_model` treats categorical columns as
  `indicator_column`s. To be specific, assume the input `SparseTensor` looks
  like:

  ```python
    shape = [2, 2]
    {
        [0, 0]: "a"
        [1, 0]: "b"
        [1, 1]: "c"
    }
  ```
  `linear_model` assigns weights for the presence of "a", "b", "c" implicitly,
  just like `indicator_column`, while `input_layer` explicitly requires wrapping
  each of the categorical columns with an `embedding_column` or an
  `indicator_column`.

  Example of usage:

  ```python
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
  keywords = categorical_column_with_hash_bucket("keywords", 10000)
  keywords_price = crossed_column(['keywords', price_buckets], ...)
  columns = [price_buckets, keywords, keywords_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  prediction = linear_model(features, columns)
  ```

  The `sparse_combiner` argument works as follows.
  For example, for two features represented as the categorical columns:

  ```python
    # Feature 1

    shape = [2, 2]
    {
        [0, 0]: "a"
        [0, 1]: "b"
        [1, 0]: "c"
    }

    # Feature 2

    shape = [2, 3]
    {
        [0, 0]: "d"
        [1, 0]: "e"
        [1, 1]: "f"
        [1, 2]: "f"
    }
  ```

  with `sparse_combiner` set to "mean", the linear model outputs are:

  ```
    y_0 = 1.0 / 2.0 * ( w_a + w_b ) + w_d + b
    y_1 = w_c + 1.0 / 3.0 * ( w_e + 2.0 * w_f ) + b
  ```

  where `y_i` is the output, `b` is the bias, and `w_x` is the weight
  assigned to the presence of `x` in the input features.
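
  For instance, with unit weights `w_a = w_b = ... = w_f = 1.0` and bias
  `b = 0.0` (illustrative values, not defaults), the outputs above evaluate to
  `y_0 = 1.0 / 2.0 * (1.0 + 1.0) + 1.0 = 2.0` and
  `y_1 = 1.0 + 1.0 / 3.0 * (1.0 + 2.0 * 1.0) = 2.0`.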

  Args:
    features: A mapping from key to tensors. `_FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values are `Tensor` or `SparseTensor` depending on
      corresponding `_FeatureColumn`.
    feature_columns: An iterable containing the FeatureColumns to use as inputs
      to your model. All items should be instances of classes derived from
      `_FeatureColumn`.
    units: An integer, dimensionality of the output space. Default value is 1.
    sparse_combiner: A string specifying how to reduce if a categorical column
      is multivalent. Apart from `numeric_column`, almost all columns passed to
      `linear_model` are considered as categorical columns. Each categorical
      column is combined independently. Currently "mean", "sqrtn" and "sum" are
      supported, with "sum" the default for linear model. "sqrtn" often achieves
      good accuracy, in particular with bag-of-words columns.
        * "sum": do not normalize features in the column
        * "mean": do l1 normalization on features in the column
        * "sqrtn": do l2 normalization on features in the column
    weight_collections: A list of collection names to which the Variable will be
      added. Note that variables will also be added to collections
      `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    cols_to_vars: If not `None`, must be a dictionary that will be filled with a
      mapping from `_FeatureColumn` to associated list of `Variable`s.  For
      example, after the call, we might have cols_to_vars = {
        _NumericColumn(
          key='numeric_feature1', shape=(1,)):
        [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
        'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
        _NumericColumn(
          key='numeric_feature2', shape=(2,)):
        [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
      If a column creates no variables, its value will be an empty list. Note
      that cols_to_vars will also contain a string key 'bias' that maps to a
      list of Variables.

  Returns:
    A `Tensor` which represents predictions/logits of a linear model. Its shape
    is (batch_size, units) and its dtype is `float32`.

  Raises:
    ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
      nor `_CategoricalColumn`.
  """
  with variable_scope.variable_scope(None, 'linear_model') as vs:
    model_name = _strip_leading_slashes(vs.name)
  linear_model_layer = _LinearModel(
      feature_columns=feature_columns,
      units=units,
      sparse_combiner=sparse_combiner,
      weight_collections=weight_collections,
      trainable=trainable,
      name=model_name)
  retval = linear_model_layer(features)  # pylint: disable=not-callable
  if cols_to_vars is not None:
    cols_to_vars.update(linear_model_layer.cols_to_vars())
  return retval


def _add_to_collections(var, weight_collections):
  """Adds a var to the list of weight_collections provided.

  Handles the case for partitioned and non-partitioned variables.

  Args:
    var: A variable or Partitioned Variable.
    weight_collections: List of collections to add variable to.
  """
  for weight_collection in weight_collections:
    # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
    if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
      continue
    # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
    # so that we don't have to do this check.
    if isinstance(var, variables.PartitionedVariable):
      for constituent_var in list(var):
        ops.add_to_collection(weight_collection, constituent_var)
    else:
      ops.add_to_collection(weight_collection, var)


class _FCLinearWrapper(base.Layer):
  """Wraps a _FeatureColumn in a layer for use in a linear model.

  See `linear_model` above.
  """

  def __init__(self,
               feature_column,
               units=1,
               sparse_combiner='sum',
               weight_collections=None,
               trainable=True,
               name=None,
               **kwargs):
    super(_FCLinearWrapper, self).__init__(
        trainable=trainable, name=name, **kwargs)
    self._feature_column = feature_column
    self._units = units
    self._sparse_combiner = sparse_combiner
    self._weight_collections = weight_collections

  def build(self, _):
    if isinstance(self._feature_column, _CategoricalColumn):
      weight = self.add_variable(
          name='weights',
          shape=(self._feature_column._num_buckets, self._units),  # pylint: disable=protected-access
          initializer=init_ops.zeros_initializer(),
          trainable=self.trainable)
    else:
      num_elements = self._feature_column._variable_shape.num_elements()  # pylint: disable=protected-access
      weight = self.add_variable(
          name='weights',
          shape=[num_elements, self._units],
          initializer=init_ops.zeros_initializer(),
          trainable=self.trainable)
    _add_to_collections(weight, self._weight_collections)
    self._weight_var = weight
    self.built = True

  def call(self, builder):
    weighted_sum = _create_weighted_sum(
        column=self._feature_column,
        builder=builder,
        units=self._units,
        sparse_combiner=self._sparse_combiner,
        weight_collections=self._weight_collections,
        trainable=self.trainable,
        weight_var=self._weight_var)
    return weighted_sum


class _BiasLayer(base.Layer):
  """A layer for the bias term."""

  def __init__(self,
               units=1,
               trainable=True,
               weight_collections=None,
               name=None,
               **kwargs):
    super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
    self._units = units
    self._weight_collections = weight_collections

  def build(self, _):
    self._bias_variable = self.add_variable(
        'bias_weights',
        shape=[self._units],
        initializer=init_ops.zeros_initializer(),
        trainable=self.trainable)
    _add_to_collections(self._bias_variable, self._weight_collections)
    self.built = True

  def call(self, _):
    return self._bias_variable


def _get_expanded_variable_list(variable):
  if (isinstance(variable, variables.Variable) or
      resource_variable_ops.is_resource_variable(variable)):
    return [variable]  # Single variable case.
  else:  # Must be a PartitionedVariable, so convert into a list.
    return list(variable)


def _strip_leading_slashes(name):
  return name.rsplit('/', 1)[-1]


class _LinearModel(base.Layer):
  """Creates a linear model using feature columns.

  See `linear_model` for details.
  """

  def __init__(self,
               feature_columns,
               units=1,
               sparse_combiner='sum',
               weight_collections=None,
               trainable=True,
               name=None,
               **kwargs):
    super(_LinearModel, self).__init__(name=name, **kwargs)
    # We force keras_style to be True here, as a workaround for not being able
    # to inherit keras.layers.Layer as the base class. Setting this lets us
    # skip all the legacy behavior of base.Layer.
    # Also note that we use Layer as the base class, instead of Model, since no
    # Model-specific behavior (e.g. compile/fit) is used.
    self._keras_style = True
    self._feature_columns = _normalize_feature_columns(
        feature_columns)
    self._weight_collections = list(weight_collections or [])
    if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
      self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
      self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)

    column_layers = {}
    for column in sorted(self._feature_columns, key=lambda x: x.name):
      with variable_scope.variable_scope(
          None, default_name=column._var_scope_name) as vs:  # pylint: disable=protected-access
        # Having the fully expressed variable scope name ends up doubly
        # expressing the outer scope (scope with which this method was called)
        # in the name of the variable that would get created.
        column_name = _strip_leading_slashes(vs.name)
      column_layer = _FCLinearWrapper(column, units, sparse_combiner,
                                      self._weight_collections, trainable,
                                      column_name, **kwargs)
      column_layers[column_name] = column_layer
    self._column_layers = self._add_layers(column_layers)
    self._bias_layer = _BiasLayer(
        units=units,
        trainable=trainable,
        weight_collections=self._weight_collections,
        name='bias_layer',
        **kwargs)
    self._cols_to_vars = {}

  def cols_to_vars(self):
    """Returns a dict mapping _FeatureColumns to variables.

    See `linear_model` for more information.
    This is not populated until `call` is invoked, i.e. until the layer is
    built.
    """
    return self._cols_to_vars

  def call(self, features):
    with variable_scope.variable_scope(self.name):
      for column in self._feature_columns:
        if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
          raise ValueError(
              'Items of feature_columns must be either a '
              '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
      weighted_sums = []
      ordered_columns = []
      builder = _LazyBuilder(features)
      for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
        column = layer._feature_column  # pylint: disable=protected-access
        ordered_columns.append(column)
        weighted_sum = layer(builder)
        weighted_sums.append(weighted_sum)
        self._cols_to_vars[column] = ops.get_collection(
            ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)

      _verify_static_batch_size_equality(weighted_sums, ordered_columns)
      predictions_no_bias = math_ops.add_n(
          weighted_sums, name='weighted_sum_no_bias')
      predictions = nn_ops.bias_add(
          predictions_no_bias,
          self._bias_layer(  # pylint: disable=not-callable
              builder,
              scope=variable_scope.get_variable_scope()),  # pylint: disable=not-callable
          name='weighted_sum')
      bias = self._bias_layer.variables[0]
      self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
    return predictions

  def _add_layers(self, layers):
    # "Magic" required for keras.Model classes to track all the variables in
    # a list of layers.Layer objects.
    # TODO(ashankar): Figure out API so user code doesn't have to do this.
    for name, layer in layers.items():
      setattr(self, 'layer-%s' % name, layer)
    return layers


def _transform_features(features, feature_columns):
  """Returns transformed features based on the feature columns passed in.

  Note that you will most likely not need to use this function. Check
  `input_layer` and `linear_model` to see whether they satisfy your use case.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `_FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depending on
      corresponding `_FeatureColumn`.
    feature_columns: An iterable containing all the `_FeatureColumn`s.

  Returns:
    A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    builder = _LazyBuilder(features)
    for column in sorted(feature_columns, key=lambda x: x.name):
      with ops.name_scope(None, default_name=column.name):
        outputs[column] = builder.get(column)
  return outputs


@tf_export(v1=['feature_column.make_parse_example_spec'])
def make_parse_example_spec(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as the `features` argument of
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = categorical_column_with_vocabulary_file(...)
  feature_b = numeric_column(...)
  feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...)
  feature_a_x_feature_c = crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=make_parse_example_spec(feature_columns))
  ```

  For the above example, `make_parse_example_spec` would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `_FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `_FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, _FeatureColumn):
      raise ValueError(
          'All feature_columns must be _FeatureColumn instances. '
          'Given: {}'.format(column))
    config = column._parse_example_spec  # pylint: disable=protected-access
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


def _embedding_column(categorical_column,
                      dimension,
                      combiner='mean',
                      initializer=None,
                      ckpt_to_load_from=None,
                      tensor_name_in_ckpt=None,
                      max_norm=None,
                      trainable=True,
                      use_safe_embedding_lookup=True):
  """`_DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `_CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9), ...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with `model_fn`:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9), ...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `_CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    `_DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` is not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access

  def _creator(weight_collections, scope):
    embedding_column_layer = _EmbeddingColumnLayer(
        embedding_shape=embedding_shape,
        initializer=initializer,
        weight_collections=weight_collections,
        trainable=trainable,
        name='embedding_column_layer')
    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable

  return _EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      layer_creator=_creator,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


def _numeric_column(key,
                    shape=(1,),
                    default_value=None,
                    dtype=dtypes.float32,
                    normalizer_fn=None):
  """Represents real-valued or numerical features.

  Example:

  ```python
  price = numeric_column('price')
  columns = [price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  # or
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

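  An example with a multi-dimensional `shape` and an iterable `default_value`
  (a minimal sketch; the feature name is illustrative):

  ```python
  # Each example carries two prices; if the feature is missing during
  # `tf.Example` parsing, it is filled with [0., 0.].
  prices = numeric_column('prices', shape=(2,), default_value=[0., 0.])
  ```
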
  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers specifying the shape of the `Tensor`. A
      single integer may also be given, meaning a single-dimension `Tensor`
      with the given width. The `Tensor` representing the column will have the
      shape of [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: Defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization,
      it can be used for any kind of TensorFlow transformation.

  Returns:
    A `_NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    ValueError: if any dimension in shape is not a positive integer.
    TypeError: if `default_value` is an iterable but not compatible with `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return _NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)


def _bucketized_column(source_column, boundaries):
  """Represents discretized dense input.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000],
                  [150,   10],
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3],
            [3, 2],
            [1, 3]]
  ```

  Example:

  ```python
  price = numeric_column('price')
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)

  # or
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = numeric_column('price')
  # bucketized_column converts a numerical feature to a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `_BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, _NumericColumn):
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if (not boundaries or
      not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
    raise ValueError('boundaries must be a sorted list.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return _BucketizedColumn(source_column, tuple(boundaries))


def _categorical_column_with_hash_bucket(key,
                                         hash_bucket_size,
                                         dtype=dtypes.string):
  """Represents a sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing:
  `output_id = Hash(input_feature_string) % bucket_size` for string-type input.
  For int-type input, the value is first converted to its string representation
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  keywords = categorical_column_with_hash_bucket("keywords", 10000)
  columns = [keywords, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)

  # or
  keywords_embedded = embedding_column(keywords, 16)
  columns = [keywords_embedded, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```
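
  The bucket assignment is roughly equivalent to the following sketch (an
  illustration assuming string inputs; the column internally relies on the
  same fast string hashing exposed as `tf.strings.to_hash_bucket_fast`):

  ```python
  # Maps each string to an id in [0, 10000); equal strings always land in the
  # same bucket, so the assignment is deterministic across runs.
  ids = tf.strings.to_hash_bucket_fast(["marketing", "sales"], 10000)
  ```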

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `_HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not set or is less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return _HashedCategoricalColumn(key, hash_bucket_size, dtype)


def _categorical_column_with_vocabulary_file(key,
                                             vocabulary_file,
                                             vocabulary_size=None,
                                             num_oov_buckets=0,
                                             default_value=None,
                                             dtype=dtypes.string):
  """A `_CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned IDs 0-49,
  corresponding to their line numbers. All other values are hashed and assigned
  an ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
  in input, and other values missing from the file, will be assigned ID 0. All
  others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3), ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if it is less, later values
      are ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `_CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  if not vocabulary_file:
    raise ValueError('Missing vocabulary_file in {}.'.format(key))

  if vocabulary_size is None:
    if not gfile.Exists(vocabulary_file):
      raise ValueError('vocabulary_file in {} does not exist.'.format(key))

    with gfile.GFile(vocabulary_file) as f:
      vocabulary_size = sum(1 for _ in f)
    logging.info(
        'vocabulary_size = %d in %s is inferred from the number of elements '
        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)

  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
  if vocabulary_size < 1:
    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
  if num_oov_buckets:
    if default_value is not None:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)
  return _VocabularyFileCategoricalColumn(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
      default_value=-1 if default_value is None else default_value,
      dtype=dtype)


1282def _categorical_column_with_vocabulary_list(key,
1283                                             vocabulary_list,
1284                                             dtype=None,
1285                                             default_value=-1,
1286                                             num_oov_buckets=0):
1287  """A `_CategoricalColumn` with in-memory vocabulary.
1288
1289  Use this when your inputs are in string or integer format, and you have an
1290  in-memory vocabulary mapping each value to an integer ID. By default,
1291  out-of-vocabulary values are ignored. Use either (but not both) of
1292  `num_oov_buckets` and `default_value` to specify how to include
1293  out-of-vocabulary values.
1294
1295  For input dictionary `features`, `features[key]` is either `Tensor` or
1296  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1297  and `''` for string, which will be dropped by this feature column.
1298
1299  Example with `num_oov_buckets`:
1300  In the following example, each input in `vocabulary_list` is assigned an ID
1301  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1302  inputs are hashed and assigned an ID 4-5.
1303
1304  ```python
1305  colors = categorical_column_with_vocabulary_list(
1306      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1307      num_oov_buckets=2)
1308  columns = [colors, ...]
1309  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1310  linear_prediction, _, _ = linear_model(features, columns)
1311  ```
1312
1313  Example with `default_value`:
1314  In the following example, each input in `vocabulary_list` is assigned an ID
1315  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1316  inputs are assigned `default_value` 0.
1317
1318
1319  ```python
1320  colors = categorical_column_with_vocabulary_list(
1321      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1322  columns = [colors, ...]
1323  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1324  linear_prediction, _, _ = linear_model(features, columns)
1325  ```
1326
1327  And to make an embedding with either:
1328
1329  ```python
1330  columns = [embedding_column(colors, 3),...]
1331  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1332  dense_tensor = input_layer(features, columns)
1333  ```
1334
1335  Args:
1336    key: A unique string identifying the input feature. It is used as the
1337      column name and the dictionary key for feature parsing configs, feature
1338      `Tensor` objects, and feature columns.
1339    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1340      is mapped to the index of its value (if present) in `vocabulary_list`.
1341      Must be castable to `dtype`.
1342    dtype: The type of features. Only string and integer types are supported.
1343      If `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` cannot be specified
      together with `default_value`.
1352
1353  Returns:
1354    A `_CategoricalColumn` with in-memory vocabulary.
1355
1356  Raises:
1357    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: if `num_oov_buckets` is a negative integer.
    ValueError: if `num_oov_buckets` and `default_value` are both specified.
1360    ValueError: if `dtype` is not integer or string.
1361  """
1362  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1363    raise ValueError(
1364        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1365            vocabulary_list, key))
1366  if len(set(vocabulary_list)) != len(vocabulary_list):
1367    raise ValueError(
1368        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1369            vocabulary_list, key))
1370  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1371  if num_oov_buckets:
1372    if default_value != -1:
1373      raise ValueError(
1374          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1375              key))
1376    if num_oov_buckets < 0:
1377      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1378          num_oov_buckets, key))
1379  fc_utils.assert_string_or_int(
1380      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1381  if dtype is None:
1382    dtype = vocabulary_dtype
1383  elif dtype.is_integer != vocabulary_dtype.is_integer:
1384    raise ValueError(
1385        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1386            dtype, vocabulary_dtype, key))
1387  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1388  fc_utils.assert_key_is_string(key)
1389
1390  return _VocabularyListCategoricalColumn(
1391      key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
1392      default_value=default_value, num_oov_buckets=num_oov_buckets)
1393
1394
1395def _categorical_column_with_identity(key, num_buckets, default_value=None):
1396  """A `_CategoricalColumn` that returns identity values.
1397
  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values outside
  this range will be mapped to `default_value` if specified; otherwise an
  error is raised.

  Typically, this is used for contiguous ranges of integer indexes, but
  it doesn't have to be. This might be inefficient, however, if many IDs
  are unused. Consider `categorical_column_with_hash_bucket` in that case.
1406
  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int, which will be dropped by this feature column.

  In the following examples, each input in the range `[0, 1000000)` is assigned
  its own value as the categorical ID. All other inputs are assigned
  `default_value` 0. Note that a literal 0 in the inputs will map to the same
  ID as the default.
1414
1415  Linear model:
1416
1417  ```python
1418  video_id = categorical_column_with_identity(
1419      key='video_id', num_buckets=1000000, default_value=0)
1420  columns = [video_id, ...]
1421  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1422  linear_prediction, _, _ = linear_model(features, columns)
1423  ```
1424
1425  Embedding for a DNN model:
1426
1427  ```python
1428  columns = [embedding_column(video_id, 9),...]
1429  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1430  dense_tensor = input_layer(features, columns)
1431  ```
1432
1433  Args:
1434    key: A unique string identifying the input feature. It is used as the
1435      column name and the dictionary key for feature parsing configs, feature
1436      `Tensor` objects, and feature columns.
1437    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1438    default_value: If set, values outside of range `[0, num_buckets)` will
1439      be replaced with this value. If not set, values >= num_buckets will
1440      cause a failure while values < 0 will be dropped.
1441
1442  Returns:
1443    A `_CategoricalColumn` that returns identity values.
1444
1445  Raises:
1446    ValueError: if `num_buckets` is less than one.
1447    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1448  """
1449  if num_buckets < 1:
1450    raise ValueError(
1451        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1452  if (default_value is not None) and (
1453      (default_value < 0) or (default_value >= num_buckets)):
1454    raise ValueError(
1455        'default_value {} not in range [0, {}), column_name {}'.format(
1456            default_value, num_buckets, key))
1457  fc_utils.assert_key_is_string(key)
1458  return _IdentityCategoricalColumn(
1459      key=key, num_buckets=num_buckets, default_value=default_value)
1460
1461
1462def _indicator_column(categorical_column):
1463  """Represents multi-hot representation of given categorical column.
1464
1465  - For DNN model, `indicator_column` can be used to wrap any
1466    `categorical_column_*` (e.g., to feed to DNN). Consider to Use
1467    `embedding_column` if the number of buckets/unique(values) are large.
1468
1469  - For Wide (aka linear) model, `indicator_column` is the internal
1470    representation for categorical column when passing categorical column
1471    directly (as any element in feature_columns) to `linear_model`. See
1472    `linear_model` for details.
1473
1474  ```python
  name = indicator_column(categorical_column_with_vocabulary_list(
      'name', ['bob', 'george', 'wanda']))
1477  columns = [name, ...]
1478  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1479  dense_tensor = input_layer(features, columns)
1480
1481  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
1482  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
1483  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
1484  ```
1485
1486  Args:
1487    categorical_column: A `_CategoricalColumn` which is created by
1488      `categorical_column_with_*` or `crossed_column` functions.
1489
1490  Returns:
1491    An `_IndicatorColumn`.
1492  """
1493  return _IndicatorColumn(categorical_column)
1494
1495
1496def _weighted_categorical_column(categorical_column,
1497                                 weight_feature_key,
1498                                 dtype=dtypes.float32):
1499  """Applies weight values to a `_CategoricalColumn`.
1500
1501  Use this when each of your sparse inputs has both an ID and a value. For
1502  example, if you're representing text documents as a collection of word
1503  frequencies, you can provide 2 parallel sparse input features ('terms' and
1504  'frequencies' below).
1505
1506  Example:
1507
1508  Input `tf.Example` objects:
1509
1510  ```proto
1511  [
1512    features {
1513      feature {
1514        key: "terms"
1515        value {bytes_list {value: "very" value: "model"}}
1516      }
1517      feature {
1518        key: "frequencies"
1519        value {float_list {value: 0.3 value: 0.1}}
1520      }
1521    },
1522    features {
1523      feature {
1524        key: "terms"
1525        value {bytes_list {value: "when" value: "course" value: "human"}}
1526      }
1527      feature {
1528        key: "frequencies"
1529        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
1530      }
1531    }
1532  ]
1533  ```
1534
1535  ```python
1536  categorical_column = categorical_column_with_hash_bucket(
1537      column_name='terms', hash_bucket_size=1000)
1538  weighted_column = weighted_categorical_column(
1539      categorical_column=categorical_column, weight_feature_key='frequencies')
1540  columns = [weighted_column, ...]
1541  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1542  linear_prediction, _, _ = linear_model(features, columns)
1543  ```
1544
1545  This assumes the input dictionary contains a `SparseTensor` for key
1546  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
1547  the same indices and dense shape.
1548
1549  Args:
1550    categorical_column: A `_CategoricalColumn` created by
1551      `categorical_column_with_*` functions.
1552    weight_feature_key: String key for weight values.
1553    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
1554      are supported.
1555
1556  Returns:
    A `_CategoricalColumn` composed of two sparse features: one represents the
    IDs, the other represents the weight (value) of each ID in that example.
1559
1560  Raises:
1561    ValueError: if `dtype` is not convertible to float.
1562  """
1563  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
1564    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
1565  return _WeightedCategoricalColumn(
1566      categorical_column=categorical_column,
1567      weight_feature_key=weight_feature_key,
1568      dtype=dtype)
1569
1570
1571def _crossed_column(keys, hash_bucket_size, hash_key=None):
1572  """Returns a column for performing crosses of categorical features.
1573
1574  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
1575  the transformation can be thought of as:
1576    Hash(cartesian product of features) % `hash_bucket_size`
1577
1578  For example, if the input features are:
1579
1580  * SparseTensor referred by first key:
1581
1582    ```python
1583    shape = [2, 2]
1584    {
1585        [0, 0]: "a"
1586        [1, 0]: "b"
1587        [1, 1]: "c"
1588    }
1589    ```
1590
1591  * SparseTensor referred by second key:
1592
1593    ```python
1594    shape = [2, 1]
1595    {
1596        [0, 0]: "d"
1597        [1, 0]: "e"
1598    }
1599    ```
1600
  then the crossed feature will look like:

  ```python
  shape = [2, 2]
1605  {
1606      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
1607      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
1608      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
1609  }
1610  ```
1611
1612  Here is an example to create a linear model with crosses of string features:
1613
1614  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
1616  columns = [keywords_x_doc_terms, ...]
1617  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1618  linear_prediction = linear_model(features, columns)
1619  ```
1620
1621  You could also use vocabulary lookup before crossing:
1622
1623  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
1627  columns = [keywords_x_doc_terms, ...]
1628  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1629  linear_prediction = linear_model(features, columns)
1630  ```
1631
1632  If an input feature is of numeric type, you can use
1633  `categorical_column_with_identity`, or `bucketized_column`, as in the example:
1634
1635  ```python
1636  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10000)
1638  price = numeric_column('price')
1639  # bucketized_column converts numerical feature to a categorical one.
1640  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1642  columns = [vertical_id_x_price, ...]
1643  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1644  linear_prediction = linear_model(features, columns)
1645  ```
1646
  To use a crossed column in a DNN model, wrap it in an embedding column, as in
  this example:
1649
1650  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1652  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1653  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1654  ```
1655
1656  Args:
1657    keys: An iterable identifying the features to be crossed. Each element can
1658      be either:
1659      * string: Will use the corresponding feature which must be of string type.
1660      * `_CategoricalColumn`: Will use the transformed tensor produced by this
1661        column. Does not support hashed categorical column.
1662    hash_bucket_size: An int > 1. The number of buckets.
    hash_key: Optional. Specifies the hash_key that will be used by the
      `FingerprintCat64` function to combine the fingerprints of the crossed
      values in `SparseCrossOp`.
1665
1666  Returns:
1667    A `_CrossedColumn`.
1668
1669  Raises:
1670    ValueError: If `len(keys) < 2`.
1671    ValueError: If any of the keys is neither a string nor `_CategoricalColumn`.
1672    ValueError: If any of the keys is `_HashedCategoricalColumn`.
1673    ValueError: If `hash_bucket_size < 1`.
1674  """
  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
1678  if not keys or len(keys) < 2:
1679    raise ValueError(
1680        'keys must be a list with length > 1. Given: {}'.format(keys))
1681  for key in keys:
1682    if (not isinstance(key, six.string_types) and
1683        not isinstance(key, _CategoricalColumn)):
      raise ValueError(
          'Unsupported key type. All keys must be either a string or a '
          'categorical column (except _HashedCategoricalColumn). '
          'Given: {}'.format(key))
1688    if isinstance(key, _HashedCategoricalColumn):
1689      raise ValueError(
1690          'categorical_column_with_hash_bucket is not supported for crossing. '
1691          'Hashing before crossing will increase probability of collision. '
1692          'Instead, use the feature name as a string. Given: {}'.format(key))
1693  return _CrossedColumn(
1694      keys=tuple(keys), hash_bucket_size=hash_bucket_size,
1695      hash_key=hash_key)
1696
1697
1698# TODO(rohanj): Clearly define semantics of this layer.
1699class _EmbeddingColumnLayer(base.Layer):
1700  """A layer that stores all the state required for a embedding column."""
1701
1702  def __init__(self,
1703               embedding_shape,
1704               initializer,
1705               weight_collections=None,
1706               trainable=True,
1707               name=None,
1708               **kwargs):
1709    """Constructor.
1710
1711    Args:
1712      embedding_shape: Shape of the embedding variable used for lookup.
1713      initializer: A variable initializer function to be used in embedding
1714        variable initialization.
      weight_collections: A list of collection names to which the Variable will
        be added. Note that variables will also be added to the collections
        `tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.MODEL_VARIABLES`.
      trainable: If `True`, also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      name: Name of the layer.
      **kwargs: Additional keyword arguments.
1722    """
1723    super(_EmbeddingColumnLayer, self).__init__(
1724        trainable=trainable, name=name, **kwargs)
1725    self._embedding_shape = embedding_shape
1726    self._initializer = initializer
1727    self._weight_collections = weight_collections
1728
1729  def set_weight_collections(self, weight_collections):
1730    """Sets the weight collections for the layer.
1731
1732    Args:
1733      weight_collections: A list of collection names to which the Variable will
1734        be added.
1735    """
1736    self._weight_collections = weight_collections
1737
1738  def build(self, _):
1739    self._embedding_weight_var = self.add_variable(
1740        name='embedding_weights',
1741        shape=self._embedding_shape,
1742        dtype=dtypes.float32,
1743        initializer=self._initializer,
1744        trainable=self.trainable)
1745    if self._weight_collections and not context.executing_eagerly():
1746      _add_to_collections(self._embedding_weight_var, self._weight_collections)
1747    self.built = True
1748
1749  def call(self, _):
1750    return self._embedding_weight_var
1751
1752
1753@six.add_metaclass(abc.ABCMeta)
1754class _FeatureColumn(object):
1755  """Represents a feature column abstraction.
1756
1757  WARNING: Do not subclass this layer unless you know what you are doing:
1758  the API is subject to future changes.
1759
1760  To distinguish the concept of a feature family and a specific binary feature
1761  within a family, we refer to a feature family like "country" as a feature
1762  column. Following is an example feature in a `tf.Example` format:
1763    {key: "country",  value: [ "US" ]}
1764  In this example the value of feature is "US" and "country" refers to the
1765  column of the feature.
1766
1767  This class is an abstract class. User should not create instances of this.
1768  """
1769
1770  @abc.abstractproperty
1771  def name(self):
1772    """Returns string. Used for naming and for name_scope."""
1773    pass
1774
1775  def __lt__(self, other):
1776    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1777
    Feature columns occasionally need to be sortable, for example when used as
    keys in a features dictionary passed to a layer.
1780
1781    In CPython, `__lt__` must be defined for all objects in the
1782    sequence being sorted. If any objects do not have an `__lt__` compatible
1783    with feature column objects (such as strings), then CPython will fall back
1784    to using the `__gt__` method below.
1785    https://docs.python.org/3/library/stdtypes.html#list.sort
1786
1787    Args:
1788      other: The other object to compare to.
1789
1790    Returns:
1791      True if the string representation of this object is lexicographically less
1792      than the string representation of `other`. For FeatureColumn objects,
1793      this looks like "<__main__.FeatureColumn object at 0xa>".
1794    """
1795    return str(self) < str(other)
1796
1797  def __gt__(self, other):
1798    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1799
    Feature columns occasionally need to be sortable, for example when used as
    keys in a features dictionary passed to a layer.
1802
1803    `__gt__` is called when the "other" object being compared during the sort
1804    does not have `__lt__` defined.
1805    Example:
1806    ```
1807    # __lt__ only class
1808    class A():
1809      def __lt__(self, other): return str(self) < str(other)
1810
1811    a = A()
1812    a < "b" # True
1813    "0" < a # Error
1814
1815    # __lt__ and __gt__ class
1816    class B():
1817      def __lt__(self, other): return str(self) < str(other)
1818      def __gt__(self, other): return str(self) > str(other)
1819
1820    b = B()
1821    b < "c" # True
1822    "0" < b # True
1823    ```
1824
1825
1826    Args:
1827      other: The other object to compare to.
1828
1829    Returns:
1830      True if the string representation of this object is lexicographically
1831      greater than the string representation of `other`. For FeatureColumn
1832      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
1833    """
1834    return str(self) > str(other)
1835
1836  @property
1837  def _var_scope_name(self):
1838    """Returns string. Used for variable_scope. Defaults to self.name."""
1839    return self.name
1840
1841  @abc.abstractmethod
1842  def _transform_feature(self, inputs):
1843    """Returns intermediate representation (usually a `Tensor`).
1844
1845    Uses `inputs` to create an intermediate representation (usually a `Tensor`)
1846    that other feature columns can use.
1847
1848    Example usage of `inputs`:
    Let's say a feature column depends on a raw feature ('raw') and another
    `_FeatureColumn` (input_fc). To access the corresponding `Tensor`s, inputs
    will be used as follows:
1852
1853    ```python
1854    raw_tensor = inputs.get('raw')
1855    fc_tensor = inputs.get(input_fc)
1856    ```
1857
1858    Args:
1859      inputs: A `_LazyBuilder` object to access inputs.
1860
1861    Returns:
1862      Transformed feature `Tensor`.
1863    """
1864    pass
1865
1866  @abc.abstractproperty
1867  def _parse_example_spec(self):
1868    """Returns a `tf.Example` parsing spec as dict.
1869
    It is used to produce the parsing spec for `tf.io.parse_example`. The
    returned spec is a dict from keys ('string') to `VarLenFeature`,
    `FixedLenFeature`, and other supported objects. Please check the
    documentation of `tf.io.parse_example` for all supported spec objects.
1874
    Let's say a feature column depends on a raw feature ('raw') and another
    `_FeatureColumn` (input_fc). One possible implementation of
    _parse_example_spec is as follows:
1878
1879    ```python
1880    spec = {'raw': tf.io.FixedLenFeature(...)}
1881    spec.update(input_fc._parse_example_spec)
1882    return spec
1883    ```
1884    """
1885    pass
1886
1887  def _reset_config(self):
1888    """Resets the configuration in the column.
1889
1890    Some feature columns e.g. embedding or shared embedding columns might
1891    have some state that is needed to be reset sometimes. Use this method
1892    in that scenario.
1893    """
1894
1895
1896class _DenseColumn(_FeatureColumn):
1897  """Represents a column which can be represented as `Tensor`.
1898
1899  WARNING: Do not subclass this layer unless you know what you are doing:
1900  the API is subject to future changes.
1901
1902  Some examples of this type are: numeric_column, embedding_column,
1903  indicator_column.
1904  """
1905
1906  @abc.abstractproperty
1907  def _variable_shape(self):
1908    """`TensorShape` of `_get_dense_tensor`, without batch dimension."""
1909    pass
1910
1911  @abc.abstractmethod
1912  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
1913    """Returns a `Tensor`.
1914
    The output of this function will be used by model builder functions. For
    example, the pseudocode of `input_layer` looks like:
1917
1918    ```python
1919    def input_layer(features, feature_columns, ...):
1920      outputs = [fc._get_dense_tensor(...) for fc in feature_columns]
1921      return tf.concat(outputs)
1922    ```
1923
1924    Args:
1925      inputs: A `_LazyBuilder` object to access inputs.
      weight_collections: List of graph collections to which Variables (if any
        are created) are added.
1928      trainable: If `True` also add variables to the graph collection
1929        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1930
1931    Returns:
1932      `Tensor` of shape [batch_size] + `_variable_shape`.
1933    """
1934    pass
1935
1936
1937def _create_weighted_sum(column,
1938                         builder,
1939                         units,
1940                         sparse_combiner,
1941                         weight_collections,
1942                         trainable,
1943                         weight_var=None):
1944  """Creates a weighted sum for a dense/categorical column for linear_model."""
1945  if isinstance(column, _CategoricalColumn):
1946    return _create_categorical_column_weighted_sum(
1947        column=column,
1948        builder=builder,
1949        units=units,
1950        sparse_combiner=sparse_combiner,
1951        weight_collections=weight_collections,
1952        trainable=trainable,
1953        weight_var=weight_var)
1954  else:
1955    return _create_dense_column_weighted_sum(
1956        column=column,
1957        builder=builder,
1958        units=units,
1959        weight_collections=weight_collections,
1960        trainable=trainable,
1961        weight_var=weight_var)
1962
1963
1964def _create_dense_column_weighted_sum(column,
1965                                      builder,
1966                                      units,
1967                                      weight_collections,
1968                                      trainable,
1969                                      weight_var=None):
1970  """Create a weighted sum of a dense column for linear_model."""
1971  tensor = column._get_dense_tensor(  # pylint: disable=protected-access
1972      builder,
1973      weight_collections=weight_collections,
1974      trainable=trainable)
1975  num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
1976  batch_size = array_ops.shape(tensor)[0]
1977  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
1978  if weight_var is not None:
1979    weight = weight_var
1980  else:
1981    weight = variable_scope.get_variable(
1982        name='weights',
1983        shape=[num_elements, units],
1984        initializer=init_ops.zeros_initializer(),
1985        trainable=trainable,
1986        collections=weight_collections)
1987  return math_ops.matmul(tensor, weight, name='weighted_sum')
1988
1989
1990class _CategoricalColumn(_FeatureColumn):
1991  """Represents a categorical feature.
1992
  WARNING: Do not subclass this class unless you know what you are doing:
  the API is subject to future changes.

  A categorical feature is typically handled with a `tf.sparse.SparseTensor`
  of IDs.
1998  """
1999
2000  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
2001      'IdWeightPair', ['id_tensor', 'weight_tensor'])
2002
2003  @abc.abstractproperty
2004  def _num_buckets(self):
2005    """Returns number of buckets in this sparse feature."""
2006    pass
2007
2008  @abc.abstractmethod
2009  def _get_sparse_tensors(self,
2010                          inputs,
2011                          weight_collections=None,
2012                          trainable=None):
2013    """Returns an IdWeightPair.
2014
    `IdWeightPair` is a pair of `SparseTensor`s which represent IDs and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.
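
    For example, for a batch of two examples whose raw values are `[['a'],
    ['b', 'c']]`, a plausible result (with hypothetical IDs) is:

    ```python
    id_tensor = SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                             values=[3, 1, 4], dense_shape=[2, 2])
    weight_tensor = None  # all weights are taken to be 1
    ```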
2024
2025    Args:
      inputs: A `_LazyBuilder` as a cache to get input tensors required to
        create `IdWeightPair`.
      weight_collections: List of graph collections to which variables (if any
        are created) are added.
2030      trainable: If `True` also add variables to the graph collection
2031        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.compat.v1.get_variable`).
2032    """
2033    pass
2034
2035
2036def _create_categorical_column_weighted_sum(column,
2037                                            builder,
2038                                            units,
2039                                            sparse_combiner,
2040                                            weight_collections,
2041                                            trainable,
2042                                            weight_var=None):
2043  # pylint: disable=g-doc-return-or-yield,g-doc-args
2044  """Create a weighted sum of a categorical column for linear_model.
2045
2046  Note to maintainer: As implementation details, the weighted sum is
2047  implemented via embedding_lookup_sparse toward efficiency. Mathematically,
2048  they are the same.
2049
2050  To be specific, conceptually, categorical column can be treated as multi-hot
2051  vector. Say:
2052
2053  ```python
2054    x = [0 0 1]  # categorical column input
2055    w = [a b c]  # weights
2056  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.
2058
2059  Another example is
2060
2061  ```python
2062    x = [0 1 1]  # categorical column input
2063    w = [a b c]  # weights
2064  ```
  The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.

  For both cases, we can implement the weighted sum via embedding_lookup_sparse
  with sparse_combiner = "sum".
2069  """
2070
2071  sparse_tensors = column._get_sparse_tensors(  # pylint: disable=protected-access
2072      builder,
2073      weight_collections=weight_collections,
2074      trainable=trainable)
2075  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
2076      array_ops.shape(sparse_tensors.id_tensor)[0], -1
2077  ])
2078  weight_tensor = sparse_tensors.weight_tensor
2079  if weight_tensor is not None:
2080    weight_tensor = sparse_ops.sparse_reshape(
2081        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2082
2083  if weight_var is not None:
2084    weight = weight_var
2085  else:
2086    weight = variable_scope.get_variable(
2087        name='weights',
2088        shape=(column._num_buckets, units),  # pylint: disable=protected-access
2089        initializer=init_ops.zeros_initializer(),
2090        trainable=trainable,
2091        collections=weight_collections)
2092  return embedding_ops.safe_embedding_lookup_sparse(
2093      weight,
2094      id_tensor,
2095      sparse_weights=weight_tensor,
2096      combiner=sparse_combiner,
2097      name='weighted_sum')
2098
2099
2100class _SequenceDenseColumn(_FeatureColumn):
2101  """Represents dense sequence data."""
2102
2103  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
2104      'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length'])
2105
2106  @abc.abstractmethod
2107  def _get_sequence_dense_tensor(
2108      self, inputs, weight_collections=None, trainable=None):
2109    """Returns a `TensorSequenceLengthPair`."""
2110    pass
2111
2112
2113class _LazyBuilder(object):
2114  """Handles caching of transformations while building the model.
2115
2116  `_FeatureColumn` specifies how to digest an input column to the network. Some
2117  feature columns require data transformations. This class caches those
2118  transformations.
2119
  Some features may be used in more than one place. For example, one can use a
  bucketized feature both by itself and in a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle the re-use of transformed columns,
  `_LazyBuilder` caches all previously transformed columns.
2125
2126  Example:
2127  We're trying to use the following `_FeatureColumn`s:
2128
2129  ```python
2130  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
2135  ```
2136
  If we transform each column independently, then we'll get duplication of
  bucketization (one for the cross, one for the bucketized column itself).
2139  The `_LazyBuilder` eliminates this duplication.
2140  """
2141
2142  def __init__(self, features):
2143    """Creates a `_LazyBuilder`.
2144
2145    Args:
      features: A mapping from feature name (string) or `_FeatureColumn` to
        objects that are `Tensor` or `SparseTensor`, or that can be converted
        to same via `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A
        `string` key signifies a base feature (not-transformed). A
        `_FeatureColumn` key means that this `Tensor` is the output of an
        existing `_FeatureColumn` which can be reused.
2152    """
2153    self._features = features.copy()
2154    self._feature_tensors = {}
2155
2156  def get(self, key):
2157    """Returns a `Tensor` for the given key.
2158
2159    A `str` key is used to access a base feature (not-transformed). When a
2160    `_FeatureColumn` is passed, the transformed feature is returned if it
2161    already exists, otherwise the given `_FeatureColumn` is asked to provide its
2162    transformed output, which is then cached.
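
    A minimal sketch; `bucketized_age` below is an illustrative
    `_FeatureColumn`, not defined here:

    ```python
    builder = _LazyBuilder({'age': tf.constant([[23], [41]])})
    age = builder.get('age')               # base feature, converted and cached
    buckets = builder.get(bucketized_age)  # column transforms; result cached
    ```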
2163
2164    Args:
2165      key: a `str` or a `_FeatureColumn`.
2166
2167    Returns:
2168      The transformed `Tensor` corresponding to the `key`.
2169
2170    Raises:
2171      ValueError: if key is not found or a transformed `Tensor` cannot be
2172        computed.
2173    """
2174    if key in self._feature_tensors:
2175      # FeatureColumn is already transformed or converted.
2176      return self._feature_tensors[key]
2177
2178    if key in self._features:
2179      feature_tensor = self._get_raw_feature_as_tensor(key)
2180      self._feature_tensors[key] = feature_tensor
2181      return feature_tensor
2182
2183    if isinstance(key, six.string_types):
2184      raise ValueError('Feature {} is not in features dictionary.'.format(key))
2185
2186    if not isinstance(key, _FeatureColumn):
2187      raise TypeError('"key" must be either a "str" or "_FeatureColumn". '
2188                      'Provided: {}'.format(key))
2189
2190    column = key
2191    logging.debug('Transforming feature_column %s.', column)
2192    transformed = column._transform_feature(self)  # pylint: disable=protected-access
2193    if transformed is None:
2194      raise ValueError('Column {} is not supported.'.format(column.name))
2195    self._feature_tensors[column] = transformed
2196    return transformed
2197
2198  def _get_raw_feature_as_tensor(self, key):
2199    """Gets the raw_feature (keyed by `key`) as `tensor`.
2200
2201    The raw feature is converted to (sparse) tensor and maybe expand dim.
2202
2203    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
2204    the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
2205    error out as it is not supported.
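
    For example, a rank-1 `Tensor` of shape `[batch_size]` is expanded to
    shape `[batch_size, 1]`:

    ```python
    # raw feature:  tf.constant([1, 2, 3])   # shape [3]
    # returned:     a Tensor of shape [3, 1]
    ```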
2206
2207    Args:
2208      key: A `str` key to access the raw feature.
2209
2210    Returns:
2211      A `Tensor` or `SparseTensor`.
2212
2213    Raises:
2214      ValueError: if the raw feature has rank 0.
2215    """
2216    raw_feature = self._features[key]
2217    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2218        raw_feature)
2219
2220    def expand_dims(input_tensor):
2221      # Input_tensor must have rank 1.
2222      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2223        return sparse_ops.sparse_reshape(
2224            input_tensor, [array_ops.shape(input_tensor)[0], 1])
2225      else:
2226        return array_ops.expand_dims(input_tensor, -1)
2227
2228    rank = feature_tensor.get_shape().ndims
2229    if rank is not None:
2230      if rank == 0:
2231        raise ValueError(
2232            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2233                key, feature_tensor))
2234      return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2235
2236    # Handle dynamic rank.
2237    with ops.control_dependencies([
2238        check_ops.assert_positive(
2239            array_ops.rank(feature_tensor),
2240            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2241                key, feature_tensor))]):
2242      return control_flow_ops.cond(
2243          math_ops.equal(1, array_ops.rank(feature_tensor)),
2244          lambda: expand_dims(feature_tensor),
2245          lambda: feature_tensor)
2246
2247
2248# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2249def _shape_offsets(shape):
2250  """Returns moving offset for each dimension given shape."""
2251  offsets = []
2252  for dim in reversed(shape):
2253    if offsets:
2254      offsets.append(dim * offsets[-1])
2255    else:
2256      offsets.append(dim)
2257  offsets.reverse()
2258  return offsets
2259
2260
2261# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2262def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2263  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2264
2265  If `input_tensor` is already a `SparseTensor`, just return it.
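
  For example, with string inputs and the default ignore value `''`:

  ```python
  # dense input:  [['a', ''], ['b', 'c']]
  # result:       SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
  #               values=['a', 'b', 'c'], dense_shape=[2, 2])
  ```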
2266
2267  Args:
2268    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, the default value of
      `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2272
2273  Returns:
2274    A `SparseTensor` with the same shape as `input_tensor`.
2275
2276  Raises:
2277    ValueError: when `input_tensor`'s rank is `None`.
2278  """
2279  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2280      input_tensor)
2281  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2282    return input_tensor
2283  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
2284    if ignore_value is None:
2285      if input_tensor.dtype == dtypes.string:
        # Special case: TF strings are converted to numpy objects by default.
2287        ignore_value = ''
2288      elif input_tensor.dtype.is_integer:
2289        ignore_value = -1  # -1 has a special meaning of missing feature
2290      else:
2291        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2292        # constructing a new numpy object of the given type, which yields the
2293        # default value for that type.
2294        ignore_value = input_tensor.dtype.as_numpy_dtype()
2295    ignore_value = math_ops.cast(
2296        ignore_value, input_tensor.dtype, name='ignore_value')
2297    indices = array_ops.where(
2298        math_ops.not_equal(input_tensor, ignore_value), name='indices')
2299    return sparse_tensor_lib.SparseTensor(
2300        indices=indices,
2301        values=array_ops.gather_nd(input_tensor, indices, name='values'),
2302        dense_shape=array_ops.shape(
2303            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2304
2305
2306def _normalize_feature_columns(feature_columns):
2307  """Normalizes the `feature_columns` input.
2308
  This method converts `feature_columns` to a list, as best it can. In
  addition, it verifies the type and other properties of `feature_columns`
  required by downstream libraries.
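
  For example, a single column is wrapped into a list (a minimal sketch using
  the public `numeric_column` for illustration):

  ```python
  columns = _normalize_feature_columns(numeric_column('age'))
  # columns == [numeric_column('age')]
  ```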
2312
2313  Args:
2314    feature_columns: The raw feature columns, usually passed by users.
2315
2316  Returns:
2317    The normalized feature column list.
2318
2319  Raises:
2320    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2321  """
2322  if isinstance(feature_columns, _FeatureColumn):
2323    feature_columns = [feature_columns]
2324
2325  if isinstance(feature_columns, collections_abc.Iterator):
2326    feature_columns = list(feature_columns)
2327
2328  if isinstance(feature_columns, dict):
2329    raise ValueError('Expected feature_columns to be iterable, found dict.')
2330
2331  for column in feature_columns:
2332    if not isinstance(column, _FeatureColumn):
2333      raise ValueError('Items of feature_columns must be a _FeatureColumn. '
2334                       'Given (type {}): {}.'.format(type(column), column))
2335  if not feature_columns:
2336    raise ValueError('feature_columns must not be empty.')
2337  name_to_column = {}
2338  for column in feature_columns:
2339    if column.name in name_to_column:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer to '
                       'the same base feature. Either one must be discarded or '
                       'a duplicated but renamed item must be inserted in the '
                       'features dict.'.format(column,
                                               name_to_column[column.name]))
2346    name_to_column[column.name] = column
2347
2348  return feature_columns
2349
2350
2351class _NumericColumn(_DenseColumn,
2352                     collections.namedtuple('_NumericColumn', [
2353                         'key', 'shape', 'default_value', 'dtype',
2354                         'normalizer_fn'
2355                     ])):
2356  """see `numeric_column`."""
2357
2358  @property
2359  def name(self):
2360    return self.key
2361
2362  @property
2363  def _parse_example_spec(self):
2364    return {
2365        self.key:
2366            parsing_ops.FixedLenFeature(self.shape, self.dtype,
2367                                        self.default_value)
2368    }
2369
2370  def _transform_feature(self, inputs):
2371    input_tensor = inputs.get(self.key)
2372    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2373      raise ValueError(
2374          'The corresponding Tensor of numerical column must be a Tensor. '
2375          'SparseTensor is not supported. key: {}'.format(self.key))
2376    if self.normalizer_fn is not None:
2377      input_tensor = self.normalizer_fn(input_tensor)
2378    return math_ops.cast(input_tensor, dtypes.float32)
2379
2380  @property
2381  def _variable_shape(self):
2382    return tensor_shape.TensorShape(self.shape)
2383
2384  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2385    """Returns dense `Tensor` representing numeric feature.
2386
2387    Args:
2388      inputs: A `_LazyBuilder` object to access inputs.
2389      weight_collections: Unused `weight_collections` since no variables are
2390        created in this function.
2391      trainable: Unused `trainable` bool since no variables are created in
2392        this function.
2393
2394    Returns:
2395      Dense `Tensor` created within `_transform_feature`.
2396    """
2397    # Do nothing with weight_collections and trainable since no variables are
2398    # created in this function.
2399    del weight_collections
2400    del trainable
2401    # Feature has been already transformed. Return the intermediate
2402    # representation created by _transform_feature.
2403    return inputs.get(self)
2404
2405
2406class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
2407                        collections.namedtuple('_BucketizedColumn', [
2408                            'source_column', 'boundaries'])):
2409  """See `bucketized_column`."""
2410
2411  @property
2412  def name(self):
2413    return '{}_bucketized'.format(self.source_column.name)
2414
2415  @property
2416  def _parse_example_spec(self):
2417    return self.source_column._parse_example_spec  # pylint: disable=protected-access
2418
2419  def _transform_feature(self, inputs):
2420    source_tensor = inputs.get(self.source_column)
2421    return math_ops._bucketize(  # pylint: disable=protected-access
2422        source_tensor,
2423        boundaries=self.boundaries)
2424
2425  @property
2426  def _variable_shape(self):
2427    return tensor_shape.TensorShape(
2428        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
2429
2430  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2431    del weight_collections
2432    del trainable
2433    input_tensor = inputs.get(self)
2434    return array_ops.one_hot(
2435        indices=math_ops.cast(input_tensor, dtypes.int64),
2436        depth=len(self.boundaries) + 1,
2437        on_value=1.,
2438        off_value=0.)
2439
2440  @property
2441  def _num_buckets(self):
2442    # By construction, source_column is always one-dimensional.
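    # e.g. 3 boundaries (4 buckets) and a source column of shape (2,) give 8.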
2443    return (len(self.boundaries) + 1) * self.source_column.shape[0]
2444
2445  def _get_sparse_tensors(self, inputs, weight_collections=None,
2446                          trainable=None):
2447    """Converts dense inputs to SparseTensor so downstream code can use it."""
2448    input_tensor = inputs.get(self)
2449    batch_size = array_ops.shape(input_tensor)[0]
2450    # By construction, source_column is always one-dimensional.
2451    source_dimension = self.source_column.shape[0]
2452
2453    i1 = array_ops.reshape(
2454        array_ops.tile(
2455            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
2456            [1, source_dimension]),
2457        (-1,))
2458    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
2459    # Flatten the bucket indices and unique them across dimensions
2460    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
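    # Concretely, with 2 boundaries (3 buckets) and source_dimension 2, input
    # buckets [[0, 2]] become ids [0 + 3*0, 2 + 3*1] = [0, 5].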
2461    bucket_indices = (
2462        array_ops.reshape(input_tensor, (-1,)) +
2463        (len(self.boundaries) + 1) * i2)
2464
2465    indices = math_ops.cast(
2466        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
2467    dense_shape = math_ops.cast(
2468        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
2469    sparse_tensor = sparse_tensor_lib.SparseTensor(
2470        indices=indices,
2471        values=bucket_indices,
2472        dense_shape=dense_shape)
2473    return _CategoricalColumn.IdWeightPair(sparse_tensor, None)
2474
2475
2476class _EmbeddingColumn(
2477    _DenseColumn, _SequenceDenseColumn,
2478    collections.namedtuple(
2479        '_EmbeddingColumn',
2480        ('categorical_column', 'dimension', 'combiner', 'layer_creator',
2481         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
2482         'use_safe_embedding_lookup'))):
2483  """See `embedding_column`."""
2484
2485  def __new__(cls,
2486              categorical_column,
2487              dimension,
2488              combiner,
2489              layer_creator,
2490              ckpt_to_load_from,
2491              tensor_name_in_ckpt,
2492              max_norm,
2493              trainable,
2494              use_safe_embedding_lookup=True):
2495    return super(_EmbeddingColumn, cls).__new__(
2496        cls,
2497        categorical_column=categorical_column,
2498        dimension=dimension,
2499        combiner=combiner,
2500        layer_creator=layer_creator,
2501        ckpt_to_load_from=ckpt_to_load_from,
2502        tensor_name_in_ckpt=tensor_name_in_ckpt,
2503        max_norm=max_norm,
2504        trainable=trainable,
2505        use_safe_embedding_lookup=use_safe_embedding_lookup)
2506
2507  @property
2508  def name(self):
2509    if not hasattr(self, '_name'):
2510      self._name = '{}_embedding'.format(self.categorical_column.name)
2511    return self._name
2512
2513  @property
2514  def _parse_example_spec(self):
2515    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
2516
2517  def _transform_feature(self, inputs):
2518    return inputs.get(self.categorical_column)
2519
2520  @property
2521  def _variable_shape(self):
2522    if not hasattr(self, '_shape'):
2523      self._shape = tensor_shape.TensorShape([self.dimension])
2524    return self._shape
2525
2526  def _get_dense_tensor_internal(self,
2527                                 inputs,
2528                                 weight_collections=None,
2529                                 trainable=None):
2530    """Private method that follows the signature of _get_dense_tensor."""
2531    # Get sparse IDs and weights.
2532    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
2533        inputs, weight_collections=weight_collections, trainable=trainable)
2534    sparse_ids = sparse_tensors.id_tensor
2535    sparse_weights = sparse_tensors.weight_tensor
2536
2537    embedding_weights = self.layer_creator(
2538        weight_collections=weight_collections,
2539        scope=variable_scope.get_variable_scope())
2540
2541    if self.ckpt_to_load_from is not None:
2542      to_restore = embedding_weights
2543      if isinstance(to_restore, variables.PartitionedVariable):
2544        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
2545      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
2546          self.tensor_name_in_ckpt: to_restore
2547      })
2548
2549    sparse_id_rank = tensor_shape.dimension_value(
2550        sparse_ids.dense_shape.get_shape()[0])
2551    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
2552    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
2553        sparse_id_rank <= 2):
2554      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
2555    # Return embedding lookup result.
2556    return embedding_lookup_sparse(
2557        embedding_weights,
2558        sparse_ids,
2559        sparse_weights,
2560        combiner=self.combiner,
2561        name='%s_weights' % self.name,
2562        max_norm=self.max_norm)
2563
2564  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2565    if isinstance(self.categorical_column, _SequenceCategoricalColumn):
2566      raise ValueError(
2567          'In embedding_column: {}. '
2568          'categorical_column must not be of type _SequenceCategoricalColumn. '
2569          'Suggested fix A: If you wish to use input_layer, use a '
2570          'non-sequence categorical_column_with_*. '
2571          'Suggested fix B: If you wish to create sequence input, use '
2572          'sequence_input_layer instead of input_layer. '
2573          'Given (type {}): {}'.format(
2574              self.name, type(self.categorical_column),
2575              self.categorical_column))
2576    return self._get_dense_tensor_internal(
2577        inputs=inputs,
2578        weight_collections=weight_collections,
2579        trainable=trainable)
2580
2581  def _get_sequence_dense_tensor(
2582      self, inputs, weight_collections=None, trainable=None):
2583    if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
2584      raise ValueError(
2585          'In embedding_column: {}. '
2586          'categorical_column must be of type _SequenceCategoricalColumn '
2587          'to use sequence_input_layer. '
2588          'Suggested fix: Use one of sequence_categorical_column_with_*. '
2589          'Given (type {}): {}'.format(
2590              self.name, type(self.categorical_column),
2591              self.categorical_column))
2592    dense_tensor = self._get_dense_tensor_internal(  # pylint: disable=protected-access
2593        inputs=inputs,
2594        weight_collections=weight_collections,
2595        trainable=trainable)
2596
2597    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
2598    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
2599        sparse_tensors.id_tensor)
2600    return _SequenceDenseColumn.TensorSequenceLengthPair(
2601        dense_tensor=dense_tensor, sequence_length=sequence_length)
2602
2603
2604def _get_graph_for_variable(var):
2605  if isinstance(var, variables.PartitionedVariable):
2606    return list(var)[0].graph
2607  else:
2608    return var.graph
2609
2610
2611class _SharedEmbeddingColumn(
2612    _DenseColumn, _SequenceDenseColumn,
2613    collections.namedtuple(
2614        '_SharedEmbeddingColumn',
2615        ('categorical_column', 'dimension', 'combiner', 'initializer',
2616         'shared_embedding_collection_name', 'ckpt_to_load_from',
2617         'tensor_name_in_ckpt', 'max_norm', 'trainable',
2618         'use_safe_embedding_lookup'))):
2619  """See `embedding_column`."""
2620
2621  @property
2622  def name(self):
2623    if not hasattr(self, '_name'):
2624      self._name = '{}_shared_embedding'.format(self.categorical_column.name)
2625    return self._name
2626
2627  @property
2628  def _var_scope_name(self):
2629    return self.shared_embedding_collection_name
2630
2631  @property
2632  def _parse_example_spec(self):
2633    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
2634
2635  def _transform_feature(self, inputs):
2636    return inputs.get(self.categorical_column)
2637
2638  @property
2639  def _variable_shape(self):
2640    if not hasattr(self, '_shape'):
2641      self._shape = tensor_shape.TensorShape([self.dimension])
2642    return self._shape
2643
2644  def _get_dense_tensor_internal(self,
2645                                 inputs,
2646                                 weight_collections=None,
2647                                 trainable=None):
2648    """Private method that follows the signature of _get_dense_tensor."""
2649    # This method is called from a variable_scope with name _var_scope_name,
2650    # which is shared among all shared embeddings. Open a name_scope here, so
2651    # that the ops for different columns have distinct names.
2652    with ops.name_scope(None, default_name=self.name):
2653      # Get sparse IDs and weights.
2654      sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
2655          inputs, weight_collections=weight_collections, trainable=trainable)
2656      sparse_ids = sparse_tensors.id_tensor
2657      sparse_weights = sparse_tensors.weight_tensor
2658
2659      embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
2660      shared_embedding_collection = ops.get_collection(
2661          self.shared_embedding_collection_name)
2662      if shared_embedding_collection:
2663        if len(shared_embedding_collection) > 1:
2664          raise ValueError(
2665              'Collection {} can only contain one variable. '
2666              'Suggested fix A: Choose a unique name for this collection. '
2667              'Suggested fix B: Do not add any variables to this collection. '
2668              'The feature_column library already adds a variable under the '
2669              'hood.'.format(shared_embedding_collection))
2670        embedding_weights = shared_embedding_collection[0]
2671        if embedding_weights.get_shape() != embedding_shape:
2672          raise ValueError(
2673              'Shared embedding collection {} contains variable {} of '
2674              'unexpected shape {}. Expected shape is {}. '
2675              'Suggested fix A: Choose a unique name for this collection. '
2676              'Suggested fix B: Do not add any variables to this collection. '
2677              'The feature_column library already adds a variable under the '
2678              'hood.'.format(self.shared_embedding_collection_name,
2679                             embedding_weights.name,
2680                             embedding_weights.get_shape(), embedding_shape))
2681      else:
2682        embedding_weights = variable_scope.get_variable(
2683            name='embedding_weights',
2684            shape=embedding_shape,
2685            dtype=dtypes.float32,
2686            initializer=self.initializer,
2687            trainable=self.trainable and trainable,
2688            collections=weight_collections)
2689        ops.add_to_collection(self.shared_embedding_collection_name,
2690                              embedding_weights)
2691      if self.ckpt_to_load_from is not None:
2692        to_restore = embedding_weights
2693        if isinstance(to_restore, variables.PartitionedVariable):
2694          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
2695        checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
2696            self.tensor_name_in_ckpt: to_restore
2697        })
2698
2699      sparse_id_rank = tensor_shape.dimension_value(
2700          sparse_ids.dense_shape.get_shape()[0])
2701      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
2702      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
2703          sparse_id_rank <= 2):
2704        embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
2705      # Return embedding lookup result.
2706      return embedding_lookup_sparse(
2707          embedding_weights,
2708          sparse_ids,
2709          sparse_weights,
2710          combiner=self.combiner,
2711          name='%s_weights' % self.name,
2712          max_norm=self.max_norm)
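
  # Lookup-dispatch sketch (illustrative, not part of the public API): for a
  # rank-2 `SparseTensor` of ids such as
  #
  #   sparse_ids = tf.sparse.SparseTensor(
  #       indices=[[0, 0], [1, 0], [1, 1]],
  #       values=[3, 1, 4],
  #       dense_shape=[2, 2])
  #
  # `embedding_lookup_sparse_v2` is chosen only when
  # `use_safe_embedding_lookup=False` and the static rank is known to be <= 2;
  # otherwise `safe_embedding_lookup_sparse` is used, which additionally prunes
  # invalid ids and fills rows that have no ids with zero vectors.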
2713
2714  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2715    if isinstance(self.categorical_column, _SequenceCategoricalColumn):
2716      raise ValueError(
2717          'In embedding_column: {}. '
2718          'categorical_column must not be of type _SequenceCategoricalColumn. '
2719          'Suggested fix A: If you wish to use input_layer, use a '
2720          'non-sequence categorical_column_with_*. '
2721          'Suggested fix B: If you wish to create sequence input, use '
2722          'sequence_input_layer instead of input_layer. '
2723          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
2724                                       self.categorical_column))
2725    return self._get_dense_tensor_internal(
2726        inputs=inputs,
2727        weight_collections=weight_collections,
2728        trainable=trainable)
2729
2730  def _get_sequence_dense_tensor(self,
2731                                 inputs,
2732                                 weight_collections=None,
2733                                 trainable=None):
2734    if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
2735      raise ValueError(
2736          'In embedding_column: {}. '
2737          'categorical_column must be of type _SequenceCategoricalColumn '
2738          'to use sequence_input_layer. '
2739          'Suggested fix: Use one of sequence_categorical_column_with_*. '
2740          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
2741                                       self.categorical_column))
2742    dense_tensor = self._get_dense_tensor_internal(  # pylint: disable=protected-access
2743        inputs=inputs,
2744        weight_collections=weight_collections,
2745        trainable=trainable)
2746    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
2747    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
2748        sparse_tensors.id_tensor)
2749    return _SequenceDenseColumn.TensorSequenceLengthPair(
2750        dense_tensor=dense_tensor, sequence_length=sequence_length)
2751
2752
2753def _check_shape(shape, key):
  """Returns `shape` as a tuple if valid; raises an error otherwise."""
2755  assert shape is not None
2756  if not nest.is_sequence(shape):
2757    shape = [shape]
2758  shape = tuple(shape)
2759  for dimension in shape:
2760    if not isinstance(dimension, six.integer_types):
      raise TypeError('shape dimensions must be integers. '
                      'shape: {}, key: {}'.format(shape, key))
2763    if dimension < 1:
2764      raise ValueError('shape dimensions must be greater than 0. '
2765                       'shape: {}, key: {}'.format(shape, key))
2766  return shape
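
# Example (illustrative): `_check_shape` normalizes a scalar or a sequence into
# a tuple of positive ints and validates it for the feature `key`:
#
#   _check_shape(10, key='age')        # -> (10,)
#   _check_shape([2, 3], key='image')  # -> (2, 3)
#   _check_shape(0, key='bad')         # raises ValueError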
2767
2768
2769class _HashedCategoricalColumn(
2770    _CategoricalColumn,
2771    collections.namedtuple('_HashedCategoricalColumn',
2772                           ['key', 'hash_bucket_size', 'dtype'])):
2773  """see `categorical_column_with_hash_bucket`."""
2774
2775  @property
2776  def name(self):
2777    return self.key
2778
2779  @property
2780  def _parse_example_spec(self):
2781    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
2782
2783  def _transform_feature(self, inputs):
2784    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
2785    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'categorical_column_with_hash_bucket input must be a SparseTensor.')
2787
2788    fc_utils.assert_string_or_int(
2789        input_tensor.dtype,
2790        prefix='column_name: {} input_tensor'.format(self.key))
2791
2792    if self.dtype.is_integer != input_tensor.dtype.is_integer:
2793      raise ValueError(
2794          'Column dtype and SparseTensors dtype must be compatible. '
2795          'key: {}, column dtype: {}, tensor dtype: {}'.format(
2796              self.key, self.dtype, input_tensor.dtype))
2797
2798    if self.dtype == dtypes.string:
2799      sparse_values = input_tensor.values
2800    else:
2801      sparse_values = string_ops.as_string(input_tensor.values)
2802
2803    sparse_id_values = string_ops.string_to_hash_bucket_fast(
2804        sparse_values, self.hash_bucket_size, name='lookup')
2805    return sparse_tensor_lib.SparseTensor(
2806        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
2807
2808  @property
2809  def _num_buckets(self):
2810    """Returns number of buckets in this sparse feature."""
2811    return self.hash_bucket_size
2812
2813  def _get_sparse_tensors(self, inputs, weight_collections=None,
2814                          trainable=None):
2815    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
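
# Usage sketch for the public wrapper (see
# `categorical_column_with_hash_bucket`); the feature key is hypothetical:
#
#   keywords = categorical_column_with_hash_bucket(
#       'keywords', hash_bucket_size=10000)
#
# String values are hashed with `string_to_hash_bucket_fast` into
# [0, hash_bucket_size); integer inputs are first converted to strings, so
# distinct values may still collide in the same bucket.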
2816
2817
2818class _VocabularyFileCategoricalColumn(
2819    _CategoricalColumn,
2820    collections.namedtuple('_VocabularyFileCategoricalColumn', (
2821        'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype',
2822        'default_value'
2823    ))):
2824  """See `categorical_column_with_vocabulary_file`."""
2825
2826  @property
2827  def name(self):
2828    return self.key
2829
2830  @property
2831  def _parse_example_spec(self):
2832    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
2833
2834  def _transform_feature(self, inputs):
2835    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
2836
2837    if self.dtype.is_integer != input_tensor.dtype.is_integer:
2838      raise ValueError(
2839          'Column dtype and SparseTensors dtype must be compatible. '
2840          'key: {}, column dtype: {}, tensor dtype: {}'.format(
2841              self.key, self.dtype, input_tensor.dtype))
2842
2843    fc_utils.assert_string_or_int(
2844        input_tensor.dtype,
2845        prefix='column_name: {} input_tensor'.format(self.key))
2846
2847    key_dtype = self.dtype
2848    if input_tensor.dtype.is_integer:
2849      # `index_table_from_file` requires 64-bit integer keys.
2850      key_dtype = dtypes.int64
2851      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
2852
2853    return lookup_ops.index_table_from_file(
2854        vocabulary_file=self.vocabulary_file,
2855        num_oov_buckets=self.num_oov_buckets,
2856        vocab_size=self.vocabulary_size,
2857        default_value=self.default_value,
2858        key_dtype=key_dtype,
2859        name='{}_lookup'.format(self.key)).lookup(input_tensor)
2860
2861  @property
2862  def _num_buckets(self):
2863    """Returns number of buckets in this sparse feature."""
2864    return self.vocabulary_size + self.num_oov_buckets
2865
2866  def _get_sparse_tensors(
2867      self, inputs, weight_collections=None, trainable=None):
2868    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
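
# Usage sketch (see `categorical_column_with_vocabulary_file`); the file path
# and key are hypothetical:
#
#   states = categorical_column_with_vocabulary_file(
#       key='state', vocabulary_file='/path/to/states.txt',
#       vocabulary_size=50, num_oov_buckets=5)
#
# In-vocabulary values map to ids [0, 50); out-of-vocabulary values are hashed
# into the OOV range [50, 55), so `_num_buckets` is 55.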
2869
2870
2871class _VocabularyListCategoricalColumn(
2872    _CategoricalColumn,
2873    collections.namedtuple('_VocabularyListCategoricalColumn', (
2874        'key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'
2875    ))):
2876  """See `categorical_column_with_vocabulary_list`."""
2877
2878  @property
2879  def name(self):
2880    return self.key
2881
2882  @property
2883  def _parse_example_spec(self):
2884    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
2885
2886  def _transform_feature(self, inputs):
2887    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
2888
2889    if self.dtype.is_integer != input_tensor.dtype.is_integer:
2890      raise ValueError(
2891          'Column dtype and SparseTensors dtype must be compatible. '
2892          'key: {}, column dtype: {}, tensor dtype: {}'.format(
2893              self.key, self.dtype, input_tensor.dtype))
2894
2895    fc_utils.assert_string_or_int(
2896        input_tensor.dtype,
2897        prefix='column_name: {} input_tensor'.format(self.key))
2898
2899    key_dtype = self.dtype
2900    if input_tensor.dtype.is_integer:
2901      # `index_table_from_tensor` requires 64-bit integer keys.
2902      key_dtype = dtypes.int64
2903      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
2904
2905    return lookup_ops.index_table_from_tensor(
2906        vocabulary_list=tuple(self.vocabulary_list),
2907        default_value=self.default_value,
2908        num_oov_buckets=self.num_oov_buckets,
2909        dtype=key_dtype,
2910        name='{}_lookup'.format(self.key)).lookup(input_tensor)
2911
2912  @property
2913  def _num_buckets(self):
2914    """Returns number of buckets in this sparse feature."""
2915    return len(self.vocabulary_list) + self.num_oov_buckets
2916
2917  def _get_sparse_tensors(
2918      self, inputs, weight_collections=None, trainable=None):
2919    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
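
# Usage sketch (see `categorical_column_with_vocabulary_list`); the key and
# vocabulary are hypothetical:
#
#   colors = categorical_column_with_vocabulary_list(
#       'color', vocabulary_list=('R', 'G', 'B'), num_oov_buckets=2)
#
# 'R', 'G', 'B' map to ids 0, 1, 2; any other value is hashed into the OOV
# range [3, 5), so `_num_buckets` is len(vocabulary_list) + num_oov_buckets = 5.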
2920
2921
2922class _IdentityCategoricalColumn(
2923    _CategoricalColumn,
2924    collections.namedtuple('_IdentityCategoricalColumn', (
2925        'key', 'num_buckets', 'default_value'
2926    ))):
  """See `categorical_column_with_identity`."""
2929
2930  @property
2931  def name(self):
2932    return self.key
2933
2934  @property
2935  def _parse_example_spec(self):
2936    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
2937
2938  def _transform_feature(self, inputs):
2939    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
2940
2941    if not input_tensor.dtype.is_integer:
2942      raise ValueError(
2943          'Invalid input, not integer. key: {} dtype: {}'.format(
2944              self.key, input_tensor.dtype))
2945    values = input_tensor.values
2946    if input_tensor.values.dtype != dtypes.int64:
2947      values = math_ops.cast(values, dtypes.int64, name='values')
2948    if self.default_value is not None:
2949      num_buckets = math_ops.cast(
2950          self.num_buckets, dtypes.int64, name='num_buckets')
2951      zero = math_ops.cast(0, dtypes.int64, name='zero')
2952      # Assign default for out-of-range values.
2953      values = array_ops.where(
2954          math_ops.logical_or(
2955              values < zero, values >= num_buckets, name='out_of_range'),
2956          array_ops.fill(
2957              dims=array_ops.shape(values),
2958              value=math_ops.cast(self.default_value, dtypes.int64),
2959              name='default_values'), values)
2960    return sparse_tensor_lib.SparseTensor(
2961        indices=input_tensor.indices,
2962        values=values,
2963        dense_shape=input_tensor.dense_shape)
2964
2965  @property
2966  def _num_buckets(self):
2967    """Returns number of buckets in this sparse feature."""
2968    return self.num_buckets
2969
2970  def _get_sparse_tensors(
2971      self, inputs, weight_collections=None, trainable=None):
2972    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
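
# Usage sketch (see `categorical_column_with_identity`); the key is
# hypothetical:
#
#   video_id = categorical_column_with_identity(
#       'video_id', num_buckets=1000000, default_value=0)
#
# Values in [0, num_buckets) are used as ids directly. With a non-None
# `default_value`, out-of-range values are replaced by it (the
# `array_ops.where` branch above); with `default_value=None`, out-of-range
# values pass through unchanged and may cause errors downstream.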
2973
2974
2975class _WeightedCategoricalColumn(
2976    _CategoricalColumn,
2977    collections.namedtuple('_WeightedCategoricalColumn', (
2978        'categorical_column', 'weight_feature_key', 'dtype'
2979    ))):
2980  """See `weighted_categorical_column`."""
2981
2982  @property
2983  def name(self):
2984    return '{}_weighted_by_{}'.format(
2985        self.categorical_column.name, self.weight_feature_key)
2986
2987  @property
2988  def _parse_example_spec(self):
2989    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
2990    if self.weight_feature_key in config:
2991      raise ValueError('Parse config {} already exists for {}.'.format(
2992          config[self.weight_feature_key], self.weight_feature_key))
2993    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
2994    return config
2995
2996  @property
2997  def _num_buckets(self):
2998    return self.categorical_column._num_buckets  # pylint: disable=protected-access
2999
3000  def _transform_feature(self, inputs):
3001    weight_tensor = inputs.get(self.weight_feature_key)
3002    if weight_tensor is None:
3003      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
3004    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
3005        weight_tensor)
3006    if self.dtype != weight_tensor.dtype.base_dtype:
3007      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
3008          self.dtype, weight_tensor.dtype))
3009    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
3010      # The weight tensor can be a regular Tensor. In this case, sparsify it.
3011      weight_tensor = _to_sparse_input_and_drop_ignore_values(
3012          weight_tensor, ignore_value=0.0)
3013    if not weight_tensor.dtype.is_floating:
3014      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
3015    return (inputs.get(self.categorical_column), weight_tensor)
3016
3017  def _get_sparse_tensors(
3018      self, inputs, weight_collections=None, trainable=None):
3019    del weight_collections
3020    del trainable
3021    tensors = inputs.get(self)
3022    return _CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
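
# Usage sketch (see `weighted_categorical_column`); the feature keys are
# hypothetical:
#
#   terms = categorical_column_with_vocabulary_list('terms', ('cat', 'dog'))
#   weighted_terms = weighted_categorical_column(
#       categorical_column=terms, weight_feature_key='frequencies')
#
# `_parse_example_spec` then contains both 'terms' (a string VarLenFeature)
# and 'frequencies' (a float VarLenFeature), and `_get_sparse_tensors` returns
# an `IdWeightPair` with a non-None weight tensor.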
3023
3024
3025class _CrossedColumn(
3026    _CategoricalColumn,
3027    collections.namedtuple('_CrossedColumn',
3028                           ['keys', 'hash_bucket_size', 'hash_key'])):
3029  """See `crossed_column`."""
3030
3031  @property
3032  def name(self):
3033    feature_names = []
3034    for key in _collect_leaf_level_keys(self):
3035      if isinstance(key, _FeatureColumn):
3036        feature_names.append(key.name)
3037      else:  # key must be a string
3038        feature_names.append(key)
3039    return '_X_'.join(sorted(feature_names))
3040
3041  @property
3042  def _parse_example_spec(self):
3043    config = {}
3044    for key in self.keys:
3045      if isinstance(key, _FeatureColumn):
3046        config.update(key._parse_example_spec)  # pylint: disable=protected-access
3047      else:  # key must be a string
3048        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
3049    return config
3050
3051  def _transform_feature(self, inputs):
3052    feature_tensors = []
3053    for key in _collect_leaf_level_keys(self):
3054      if isinstance(key, six.string_types):
3055        feature_tensors.append(inputs.get(key))
3056      elif isinstance(key, _CategoricalColumn):
3057        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3058        if ids_and_weights.weight_tensor is not None:
3059          raise ValueError(
3060              'crossed_column does not support weight_tensor, but the given '
3061              'column populates weight_tensor. '
3062              'Given column: {}'.format(key.name))
3063        feature_tensors.append(ids_and_weights.id_tensor)
3064      else:
3065        raise ValueError('Unsupported column type. Given: {}'.format(key))
3066    return sparse_ops.sparse_cross_hashed(
3067        inputs=feature_tensors,
3068        num_buckets=self.hash_bucket_size,
3069        hash_key=self.hash_key)
3070
3071  @property
3072  def _num_buckets(self):
3073    """Returns number of buckets in this sparse feature."""
3074    return self.hash_bucket_size
3075
3076  def _get_sparse_tensors(self, inputs, weight_collections=None,
3077                          trainable=None):
3078    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
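
# Usage sketch (see `crossed_column`); the feature keys are hypothetical:
#
#   country_x_language = crossed_column(
#       keys=['country', 'language'], hash_bucket_size=5000)
#
# Each co-occurring (country, language) pair is hashed into
# [0, hash_bucket_size) via `sparse_cross_hashed`, and the column's `name` is
# the sorted leaf keys joined with '_X_', here 'country_X_language'.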
3079
3080
3081def _collect_leaf_level_keys(cross):
3082  """Collects base keys by expanding all nested crosses.
3083
3084  Args:
3085    cross: A `_CrossedColumn`.
3086
3087  Returns:
3088    A list of strings or `_CategoricalColumn` instances.
3089  """
3090  leaf_level_keys = []
3091  for k in cross.keys:
3092    if isinstance(k, _CrossedColumn):
3093      leaf_level_keys.extend(_collect_leaf_level_keys(k))
3094    else:
3095      leaf_level_keys.append(k)
3096  return leaf_level_keys
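
# Example (illustrative): nested crosses are flattened to their base keys.
# Assuming hypothetical keys 'a', 'b', 'c':
#
#   inner = crossed_column(['a', 'b'], hash_bucket_size=10)
#   outer = crossed_column([inner, 'c'], hash_bucket_size=10)
#   _collect_leaf_level_keys(outer)  # -> ['a', 'b', 'c']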
3097
3098
3099class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
3100                       collections.namedtuple('_IndicatorColumn',
3101                                              ['categorical_column'])):
3102  """Represents a one-hot column for use in deep networks.
3103
3104  Args:
3105    categorical_column: A `_CategoricalColumn` which is created by
3106      `categorical_column_with_*` function.
3107  """
3108
3109  @property
3110  def name(self):
3111    return '{}_indicator'.format(self.categorical_column.name)
3112
3113  def _transform_feature(self, inputs):
3114    """Returns dense `Tensor` representing feature.
3115
3116    Args:
3117      inputs: A `_LazyBuilder` object to access inputs.
3118
3119    Returns:
3120      Transformed feature `Tensor`.
3121
3122    Raises:
3123      ValueError: if input rank is not known at graph building time.
3124    """
3125    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3126    id_tensor = id_weight_pair.id_tensor
3127    weight_tensor = id_weight_pair.weight_tensor
3128
3129    # If the underlying column is weighted, return the input as a dense tensor.
3130    if weight_tensor is not None:
3131      weighted_column = sparse_ops.sparse_merge(
3132          sp_ids=id_tensor,
3133          sp_values=weight_tensor,
3134          vocab_size=int(self._variable_shape[-1]))
3135      # Remove (?, -1) index.
3136      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
3137                                                weighted_column.dense_shape)
      # Use scatter_nd to merge any duplicated indices, instead of
      # sparse_tensor_to_dense.
3140      return array_ops.scatter_nd(weighted_column.indices,
3141                                  weighted_column.values,
3142                                  weighted_column.dense_shape)
3143
3144    dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
3145        id_tensor, default_value=-1)
3146
    # The one-hot tensor must be float, since all other inputs to input_layer
    # are float32 and tf.concat requires a single dtype.
3149    one_hot_id_tensor = array_ops.one_hot(
3150        dense_id_tensor,
3151        depth=self._variable_shape[-1],
3152        on_value=1.0,
3153        off_value=0.0)
3154
3155    # Reduce to get a multi-hot per example.
3156    return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
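
  # Worked example (illustrative): for a categorical column with
  # `_num_buckets=4` and one example whose ids are [1, 1, 3], the one-hot
  # tensors are summed over the sequence axis, yielding the multi-hot count
  # vector [0., 2., 0., 1.]. Missing positions (filled with -1 above) produce
  # all-zero one-hot rows and so do not contribute to the sum.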
3157
3158  @property
3159  def _parse_example_spec(self):
3160    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
3161
3162  @property
3163  def _variable_shape(self):
3164    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
3165    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access
3166
3167  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3168    """Returns dense `Tensor` representing feature.
3169
3170    Args:
3171      inputs: A `_LazyBuilder` object to access inputs.
3172      weight_collections: Unused `weight_collections` since no variables are
3173        created in this function.
3174      trainable: Unused `trainable` bool since no variables are created in
3175        this function.
3176
3177    Returns:
3178      Dense `Tensor` created within `_transform_feature`.
3179
3180    Raises:
3181      ValueError: If `categorical_column` is a `_SequenceCategoricalColumn`.
3182    """
3183    # Do nothing with weight_collections and trainable since no variables are
3184    # created in this function.
3185    del weight_collections
3186    del trainable
3187    if isinstance(self.categorical_column, _SequenceCategoricalColumn):
3188      raise ValueError(
3189          'In indicator_column: {}. '
3190          'categorical_column must not be of type _SequenceCategoricalColumn. '
3191          'Suggested fix A: If you wish to use input_layer, use a '
3192          'non-sequence categorical_column_with_*. '
3193          'Suggested fix B: If you wish to create sequence input, use '
3194          'sequence_input_layer instead of input_layer. '
3195          'Given (type {}): {}'.format(
3196              self.name, type(self.categorical_column),
3197              self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
3200    return inputs.get(self)
3201
3202  def _get_sequence_dense_tensor(
3203      self, inputs, weight_collections=None, trainable=None):
3204    # Do nothing with weight_collections and trainable since no variables are
3205    # created in this function.
3206    del weight_collections
3207    del trainable
3208    if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
3209      raise ValueError(
3210          'In indicator_column: {}. '
3211          'categorical_column must be of type _SequenceCategoricalColumn '
3212          'to use sequence_input_layer. '
3213          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3214          'Given (type {}): {}'.format(
3215              self.name, type(self.categorical_column),
3216              self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
3219    dense_tensor = inputs.get(self)
3220    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3221    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3222        sparse_tensors.id_tensor)
3223    return _SequenceDenseColumn.TensorSequenceLengthPair(
3224        dense_tensor=dense_tensor, sequence_length=sequence_length)
3225
3226
3227def _verify_static_batch_size_equality(tensors, columns):
3228  """Validates that the first dim (batch size) of all tensors are equal or None.
3229
3230  Args:
3231    tensors: list of tensors to check.
3232    columns: list of feature columns matching tensors. Will be used for error
3233      messaging.
3234
3235  Raises:
3236    ValueError: if one of the tensors has a variant batch size
3237  """
3238  # bath_size is a tf.compat.v1.Dimension object.
3239  expected_batch_size = None
3240  for i in range(0, len(tensors)):
3241    if tensors[i].shape.dims[0].value is not None:
3242      if expected_batch_size is None:
3243        bath_size_column_index = i
3244        expected_batch_size = tensors[i].shape.dims[0]
3245      elif not expected_batch_size.is_compatible_with(tensors[i].shape.dims[0]):
3246        raise ValueError(
3247            'Batch size (first dimension) of each feature must be same. '
3248            'Batch size of columns ({}, {}): ({}, {})'.format(
3249                columns[bath_size_column_index].name, columns[i].name,
3250                expected_batch_size, tensors[i].shape.dims[0]))
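
# Example (illustrative): given three tensors with static first dimensions
# 32, None, and 32, the check passes (a None batch size is skipped); if the
# third instead had a static first dimension of 64, a ValueError naming the
# two mismatching columns would be raised.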
3251
3252
3253class _SequenceCategoricalColumn(
3254    _CategoricalColumn,
3255    collections.namedtuple(
3256        '_SequenceCategoricalColumn', ['categorical_column'])):
3257  """Represents sequences of categorical data."""
3258
3259  @property
3260  def name(self):
3261    return self.categorical_column.name
3262
3263  @property
3264  def _parse_example_spec(self):
3265    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
3266
3267  def _transform_feature(self, inputs):
3268    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access
3269
3270  @property
3271  def _num_buckets(self):
3272    return self.categorical_column._num_buckets  # pylint: disable=protected-access
3273
3274  def _get_sparse_tensors(self, inputs, weight_collections=None,
3275                          trainable=None):
3276    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3277    id_tensor = sparse_tensors.id_tensor
3278    weight_tensor = sparse_tensors.weight_tensor
3279
    # Expand a third dimension, if necessary, so that embeddings are not
    # combined during the embedding lookup. If the tensor is already 3D,
    # leave it as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # -1 does not work for dynamically shaped tensors with a 0-sized dimension
    # at runtime, which happens for empty sequences.
3287    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
3288    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
3289    if weight_tensor is not None:
3290      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
3291
3292    return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
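
  # Shape example (illustrative): a rank-2 ids tensor with dense_shape
  # [batch_size, max_seq_len] is reshaped to [batch_size, max_seq_len, 1]
  # (reduce_prod over the empty trailing slice yields 1), so each sequence
  # step keeps its own embedding instead of being combined across steps.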
3293