1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This API defines FeatureColumn abstraction.
16
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
20
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
23
241. Feature type:
25
26  * Continuous features can be represented by `numeric_column`.
27  * Categorical features can be represented by any `categorical_column_with_*`
28  column:
29    - `categorical_column_with_vocabulary_list`
30    - `categorical_column_with_vocabulary_file`
31    - `categorical_column_with_hash_bucket`
32    - `categorical_column_with_identity`
33    - `weighted_categorical_column`
34
352. Model type:
36
37  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
38
39    Continuous features can be directly fed into deep neural network models.
40
41      age_column = numeric_column("age")
42
43    To feed sparse features into DNN models, wrap the column with
44    `embedding_column` or `indicator_column`. `indicator_column` is recommended
45    for features with only a few possible values. For features with many
46    possible values, to reduce the size of your model, `embedding_column` is
47    recommended.
48
49      embedded_dept_column = embedding_column(
50          categorical_column_with_vocabulary_list(
51              "department", ["math", "philosophy", ...]), dimension=10)
52
53  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
54
55    Sparse features can be fed directly into linear models. They behave like an
56    indicator column but with an efficient implementation.
57
58      dept_column = categorical_column_with_vocabulary_list("department",
59          ["math", "philosophy", "english"])
60
61    It is recommended that continuous features be bucketized before being
62    fed into linear models.
63
64      bucketized_age_column = bucketized_column(
65          source_column=age_column,
66          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
67
68    Sparse features can be crossed (also known as conjuncted or combined) in
69    order to form non-linearities, and then fed into linear models.
70
71      cross_dept_age_column = crossed_column(
72          columns=["department", bucketized_age_column],
73          hash_bucket_size=1000)
74
75Example of building canned `Estimator`s using FeatureColumns:
76
77  ```python
78  # Define features and transformations
79  deep_feature_columns = [age_column, embedded_dept_column]
80  wide_feature_columns = [dept_column, bucketized_age_column,
81      cross_dept_age_column]
82
83  # Build deep model
84  estimator = DNNClassifier(
85      feature_columns=deep_feature_columns,
86      hidden_units=[500, 250, 50])
87  estimator.train(...)
88
89  # Or build a wide model
90  estimator = LinearClassifier(
91      feature_columns=wide_feature_columns)
92  estimator.train(...)
93
94  # Or build a wide and deep model!
95  estimator = DNNLinearCombinedClassifier(
96      linear_feature_columns=wide_feature_columns,
97      dnn_feature_columns=deep_feature_columns,
98      dnn_hidden_units=[500, 250, 50])
99  estimator.train(...)
100  ```
101
102
103FeatureColumns can also be transformed into a generic input layer for
104custom models using `input_layer`.
105
Example of building a model using FeatureColumns; this can be used in a
`model_fn` which is given to the `tf.estimator.Estimator`:
108
109  ```python
110  # Building model via layers
111
112  deep_feature_columns = [age_column, embedded_dept_column]
113  columns_to_tensor = parse_feature_columns_from_examples(
114      serialized=my_data,
115      feature_columns=deep_feature_columns)
116  first_layer = input_layer(
117      features=columns_to_tensor,
118      feature_columns=deep_feature_columns)
119  second_layer = fully_connected(first_layer, ...)
120  ```
121
122NOTE: Functions prefixed with "_" indicate experimental or private parts of
123the API subject to change, and should not be relied upon!
124"""
125
126from __future__ import absolute_import
127from __future__ import division
128from __future__ import print_function
129
130import abc
131import collections
132import math
133
134import numpy as np
135import six
136
137
138from tensorflow.python.eager import context
139from tensorflow.python.feature_column import feature_column as fc_old
140from tensorflow.python.feature_column import utils as fc_utils
141from tensorflow.python.framework import dtypes
142from tensorflow.python.framework import ops
143from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
144from tensorflow.python.framework import tensor_shape
145# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
146# of the main repo.
147from tensorflow.python.keras import utils
148from tensorflow.python.keras.engine import training
149from tensorflow.python.keras.engine.base_layer import Layer
150from tensorflow.python.ops import array_ops
151from tensorflow.python.ops import check_ops
152from tensorflow.python.ops import control_flow_ops
153from tensorflow.python.ops import embedding_ops
154from tensorflow.python.ops import init_ops
155from tensorflow.python.ops import lookup_ops
156from tensorflow.python.ops import math_ops
157from tensorflow.python.ops import nn_ops
158from tensorflow.python.ops import parsing_ops
159from tensorflow.python.ops import sparse_ops
160from tensorflow.python.ops import string_ops
161from tensorflow.python.ops import variable_scope
162from tensorflow.python.ops import variables
163from tensorflow.python.platform import gfile
164from tensorflow.python.platform import tf_logging as logging
165from tensorflow.python.training import checkpoint_utils
166from tensorflow.python.training.tracking import tracking
167from tensorflow.python.util import deprecation
168from tensorflow.python.util import nest
169from tensorflow.python.util.tf_export import keras_export
170from tensorflow.python.util.tf_export import tf_export
171
172
# Deprecation metadata for the legacy `_FeatureColumn` APIs. Neither constant
# is referenced in this part of the module; presumably they are consumed by
# `deprecation` decorators elsewhere in the file (TODO confirm).
_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = (
    'The old _FeatureColumn APIs are being deprecated. '
    'Please use the new FeatureColumn APIs instead.')
177
178
class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.

  This base class only defines the interface: every method raises
  `NotImplementedError` and must be overridden by a concrete subclass.
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.

    Returns:
      The variable previously registered under `name` for `feature_column`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name
    # The message names the actual method (it previously said `get_var`).
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds an existing resource to the state.

    Resources can be things such as tables etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.

    Returns:
      The added resource.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')
261
262
class _StateManagerImpl(StateManager):
  """Manages the state of `DenseFeatures` and the linear model layer.

  Variables are created through the owning layer's `add_variable` and are
  indexed per feature column by name so they can be looked up later.
  """

  def __init__(self, layer, trainable):
    """Creates an _StateManagerImpl object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether by default, variables created are trainable or not.
    """
    self._trainable = trainable
    self._layer = layer
    # Maps feature column -> {variable name: variable}.
    self._cols_to_vars_map = collections.defaultdict(dict)

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a variable through the owning layer and records it.

    The variable is trainable only if both this manager's default
    `trainable` flag and the `trainable` argument are true.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name, unique per `feature_column`.
      shape: variable shape.
      dtype: The type of the variable.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables.
      initializer: initializer instance (callable).

    Returns:
      The created variable.

    Raises:
      ValueError: If `feature_column` already has a variable named `name`.
    """
    column_vars = self._cols_to_vars_map[feature_column]
    if name in column_vars:
      raise ValueError('Variable already exists.')

    var = self._layer.add_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=self._trainable and trainable,
        use_resource=use_resource,
        # TODO(rohanj): Get rid of this hack once we have a mechanism for
        # specifying a default partitioner for an entire layer. In that case,
        # the default getter for Layers should work.
        getter=variable_scope.get_variable)
    column_vars[name] = var
    return var

  def get_variable(self, feature_column, name):
    """Returns a variable previously created by `create_variable`.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.

    Returns:
      The variable registered under `name` for `feature_column`.

    Raises:
      ValueError: If no such variable exists.
    """
    # Use `.get` so a failed lookup does not insert an empty entry for an
    # unknown column into the defaultdict as a side effect.
    column_vars = self._cols_to_vars_map.get(feature_column, {})
    if name in column_vars:
      return column_vars[name]
    raise ValueError('Variable does not exist.')
306
307
class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers. Subclasses must implement
  `_target_shape`, which is used both for the layer's output shape and for
  reshaping each column's dense tensor.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable:  Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the DenseFeatures.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self, feature_columns, expected_column_type, trainable, name,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = _normalize_feature_columns(feature_columns)
    self._state_manager = _StateManagerImpl(self, self.trainable)
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    # Each column creates its state under `<layer name>/<column name>`.
    for column in self._feature_columns:
      with variable_scope._pure_variable_scope(self.name):  # pylint: disable=protected-access
        with variable_scope._pure_variable_scope(column.name):  # pylint: disable=protected-access
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    This is the hook called by `compute_output_shape` and
    `_process_dense_tensor`; subclasses must override it.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.
    """
    raise NotImplementedError('Calling an abstract method.')

  def _output_shape(self, input_shape, num_elements):
    """Backward-compatible alias for `_target_shape`."""
    return self._target_shape(input_shape, num_elements)

  def compute_output_shape(self, input_shape):
    """Returns the output shape: all columns' elements concatenated."""
    total_elements = sum(
        column.variable_shape.num_elements() for column in self._feature_columns)
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns."""
    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)
382
383
@keras_export('keras.layers.DenseFeatures')
class DenseFeatures(_BaseFeaturesLayer):
  """A layer that produces a dense `Tensor` based on given `feature_columns`.

  Generally a single example in training data is described with FeatureColumns.
  At the first layer of the model, this column oriented data should be converted
  to a single `Tensor`.

  This layer can be called multiple times with different features.

  Example:

  ```python
  price = numeric_column('price')
  keywords_embedded = embedding_column(
      categorical_column_with_hash_bucket("keywords", 10000), dimension=16)
  columns = [price, keywords_embedded, ...]
  feature_layer = DenseFeatures(columns)

  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = feature_layer(features)
  for units in [128, 64, 32]:
    dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
  prediction = tf.layers.dense(dense_tensor, 1)
  ```
  """

  def __init__(self,
               feature_columns,
               trainable=True,
               name=None,
               **kwargs):
    """Constructs a DenseFeatures.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model. All items should be instances of classes derived
        from `DenseColumn` such as `numeric_column`, `embedding_column`,
        `bucketized_column`, `indicator_column`. If you have categorical
        features, you can wrap them with an `embedding_column` or
        `indicator_column`.
      trainable:  Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the DenseFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` is not a `DenseColumn`.
    """
    super(DenseFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=DenseColumn,
        **kwargs)

  @property
  def _is_feature_layer(self):
    return True

  def _target_shape(self, input_shape, total_elements):
    """Returns `(batch_size, total_elements)` from `input_shape`'s first dim."""
    return (input_shape[0], total_elements)

  def call(self, features, cols_to_output_tensors=None):
    """Returns a dense tensor corresponding to the `feature_columns`.

    Args:
      features: A mapping from key to tensors. `FeatureColumn`s look up via
        these keys. For example `numeric_column('price')` will look at 'price'
        key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
        on corresponding `FeatureColumn`.
      cols_to_output_tensors: If not `None`, this will be filled with a dict
        mapping feature columns to output tensors created.

    Returns:
      A `Tensor` which represents input layer of a model. Its shape
      is (batch_size, first_layer_dimension) and its dtype is `float32`.
      first_layer_dimension is determined based on given `feature_columns`.

    Raises:
      ValueError: If features are not a dictionary.
    """
    if not isinstance(features, dict):
      # Format the message (as _LinearModelLayer does) instead of passing the
      # features as a second exception argument, which yields a tuple message.
      raise ValueError('We expected a dictionary here. Instead we got: {}'
                       .format(features))
    transformation_cache = FeatureTransformationCache(features)
    output_tensors = []
    for column in self._feature_columns:
      with ops.name_scope(column.name):
        tensor = column.get_dense_tensor(transformation_cache,
                                         self._state_manager)
        processed_tensor = self._process_dense_tensor(column, tensor)
        if cols_to_output_tensors is not None:
          cols_to_output_tensors[column] = processed_tensor
        output_tensors.append(processed_tensor)
    return self._verify_and_concat_tensors(output_tensors)
480
481
class _LinearModelLayer(Layer):
  """Keras layer holding the weights and computation behind `LinearModel`.

  For each feature column the layer owns a `(first_dim, units)` weight matrix,
  where `first_dim` is `num_buckets` for a `CategoricalColumn` and the number
  of elements of `variable_shape` otherwise. It also owns one bias vector of
  length `units`. `call` adds up the per-column weighted sums and then adds the
  bias.
  """

  def __init__(self,
               feature_columns,
               units=1,
               sparse_combiner='sum',
               trainable=True,
               name=None,
               **kwargs):
    super(_LinearModelLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)

    self._feature_columns = _normalize_feature_columns(feature_columns)
    supported_types = (DenseColumn, CategoricalColumn)
    for col in self._feature_columns:
      if isinstance(col, supported_types):
        continue
      raise ValueError(
          'Items of feature_columns must be either a '
          'DenseColumn or CategoricalColumn. Given: {}'.format(col))

    self._units = units
    self._sparse_combiner = sparse_combiner

    self._state_manager = _StateManagerImpl(self, self.trainable)
    self.bias = None

  def build(self, _):
    # Variable scopes (not just name scopes) are opened so that partitioning
    # information reaches the created variables. The *pure* variant is used
    # because `call` opens its own name_scope when it builds the ops.
    with variable_scope._pure_variable_scope(self.name):  # pylint: disable=protected-access
      for col in self._feature_columns:
        with variable_scope._pure_variable_scope(col.name):  # pylint: disable=protected-access
          # Let the column create whatever state it needs first.
          col.create_state(self._state_manager)

          # One weight row per bucket (categorical) or per dense element.
          rows = (col.num_buckets if isinstance(col, CategoricalColumn)
                  else col.variable_shape.num_elements())
          self._state_manager.create_variable(
              col,
              name='weights',
              dtype=dtypes.float32,
              shape=(rows, self._units),
              initializer=init_ops.zeros_initializer(),
              trainable=self.trainable)

      # The single bias vector shared by every example.
      self.bias = self.add_variable(
          name='bias_weights',
          dtype=dtypes.float32,
          shape=[self._units],
          initializer=init_ops.zeros_initializer(),
          trainable=self.trainable,
          use_resource=True,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)

    super(_LinearModelLayer, self).build(None)

  def call(self, features):
    """Returns `sum_over_columns(weighted_sum) + bias` for `features`.

    Raises:
      ValueError: If `features` is not a dict.
    """
    if not isinstance(features, dict):
      raise ValueError('We expected a dictionary here. Instead we got: {}'
                       .format(features))
    with ops.name_scope(self.name):
      cache = FeatureTransformationCache(features)
      per_column_sums = []
      for col in self._feature_columns:
        with ops.name_scope(col.name):
          # Weights are owned by this layer's state manager (created in build).
          col_weights = self._state_manager.get_variable(col, 'weights')
          per_column_sums.append(
              _create_weighted_sum(
                  column=col,
                  transformation_cache=cache,
                  state_manager=self._state_manager,
                  sparse_combiner=self._sparse_combiner,
                  weight_var=col_weights))

      _verify_static_batch_size_equality(per_column_sums, self._feature_columns)
      sum_without_bias = math_ops.add_n(
          per_column_sums, name='weighted_sum_no_bias')
      return nn_ops.bias_add(sum_without_bias, self.bias, name='weighted_sum')
574
575
@keras_export('keras.layers.LinearModel', v1=[])
class LinearModel(training.Model):
  """Produces a linear prediction `Tensor` based on given `feature_columns`.

  This layer generates a weighted sum based on output dimension `units`.
  Weighted sum refers to logits in classification problems. It refers to the
  prediction itself for linear regression problems.

  Note on supported columns: `LinearModel` treats categorical columns as
  `indicator_column`s. To be specific, assume the input as `SparseTensor` looks
  like:

  ```python
    shape = [2, 2]
    {
        [0, 0]: "a"
        [1, 0]: "b"
        [1, 1]: "c"
    }
  ```
  `LinearModel` assigns weights for the presence of "a", "b", "c" implicitly,
  just like `indicator_column`, while `DenseFeatures` explicitly requires
  wrapping each categorical column with an `embedding_column` or an
  `indicator_column`.

  Example of usage:

  ```python
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
  keywords = categorical_column_with_hash_bucket("keywords", 10000)
  keywords_price = crossed_column(['keywords', price_buckets], ...)
  columns = [price_buckets, keywords, keywords_price, ...]
  linear_model = LinearModel(columns)

  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  prediction = linear_model(features)
  ```
  """

  def __init__(self,
               feature_columns,
               units=1,
               sparse_combiner='sum',
               trainable=True,
               name=None,
               **kwargs):
    """Constructs a LinearModel.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model. All items should be instances of classes derived
        from `DenseColumn` or `CategoricalColumn`.
      units: An integer, dimensionality of the output space. Default value is 1.
      sparse_combiner: A string specifying how to reduce if a categorical column
        is multivalent. Except `numeric_column`, almost all columns passed to
        `LinearModel` are considered as categorical columns.  It combines each
        categorical column independently. Currently "mean", "sqrtn" and "sum"
        are supported, with "sum" the default for linear model. "sqrtn" often
        achieves good accuracy, in particular with bag-of-words columns.
          * "sum": do not normalize features in the column
          * "mean": do l1 normalization on features in the column
          * "sqrtn": do l2 normalization on features in the column
        For example, for two features represented as the categorical columns:

          ```python
          # Feature 1

          shape = [2, 2]
          {
              [0, 0]: "a"
              [0, 1]: "b"
              [1, 0]: "c"
          }

          # Feature 2

          shape = [2, 3]
          {
              [0, 0]: "d"
              [1, 0]: "e"
              [1, 1]: "f"
              [1, 2]: "g"
          }
          ```

        with `sparse_combiner` as "mean", the linear model outputs are
        conceptually:
        ```
        y_0 = 1.0 / 2.0 * ( w_a + w_b ) + w_d + b
        y_1 = w_c + 1.0 / 3.0 * ( w_e + w_f + w_g ) + b
        ```
        where `y_i` is the output for example `i`, `b` is the bias, and `w_x`
        is the weight assigned to the presence of `x` in the input features.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      name: Name to give to the Linear Model. All variables and ops created will
        be scoped by this name.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` is neither a `DenseColumn`
        nor `CategoricalColumn`.
    """

    super(LinearModel, self).__init__(name=name, **kwargs)
    # All weights and the computation live in the inner layer, which shares
    # this model's (possibly auto-generated) name.
    self.layer = _LinearModelLayer(
        feature_columns,
        units,
        sparse_combiner,
        trainable,
        name=self.name,
        **kwargs)

  def call(self, features):
    """Returns a `Tensor` that represents the predictions of a linear model.

    Args:
      features: A mapping from key to tensors. `FeatureColumn`s look up via
        these keys. For example `numeric_column('price')` will look at 'price'
        key in this dict. Values are `Tensor` or `SparseTensor` depending on
        corresponding `FeatureColumn`.

    Returns:
      A `Tensor` which represents predictions/logits of a linear model. Its
      shape is (batch_size, units) and its dtype is `float32`.

    Raises:
      ValueError: If features are not a dictionary.
    """
    return self.layer(features)

  @property
  def bias(self):
    """The bias variable of the inner layer (`None` until it is built)."""
    return self.layer.bias
710
711
def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on features columns passed in.

  Most users will not need this function directly; `DenseFeatures` and
  `LinearModel` cover the common use cases.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  transformed = _transform_features_v2(features, columns, state_manager)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depends on
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  columns = _normalize_feature_columns(feature_columns)
  transformed = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    # A single cache is shared so that columns with common inputs reuse work.
    cache = FeatureTransformationCache(features)
    for col in columns:
      with ops.name_scope(None, default_name=col.name):
        transformed[col] = cache.get(col, state_manager)
  return transformed
755
756
@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as arg 'features' in `tf.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = categorical_column_with_vocabulary_file(...)
  feature_b = numeric_column(...)
  feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...)
  feature_a_x_feature_c = crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.parse_example(
      serialized=serialized_examples,
      features=make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance, or if two columns request different specs for the same key.
  """
  merged_spec = {}
  for candidate in feature_columns:
    if not isinstance(candidate, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(candidate))
    column_spec = candidate.parse_example_spec
    # Check this column's whole spec for conflicts before merging any of it.
    for spec_key, spec_value in six.iteritems(column_spec):
      if spec_key not in merged_spec:
        continue
      if spec_value != merged_spec[spec_key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(
                spec_key, spec_value, merged_spec[spec_key]))
    merged_spec.update(column_spec)
  return merged_spec
815
816
@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  # Checkpoint restoration needs both the checkpoint and the tensor name.
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    # Scale the default init so the embedding norm is independent of dimension.
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable)
920
921
@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `categorical_columns` is empty.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # The input is traversed twice (sorting, then building the result). Copy it
  # once so a one-shot iterator is not exhausted by `sorted` and silently
  # produce an empty result.
  categorical_columns = list(categorical_columns)
  if not categorical_columns:
    raise ValueError('categorical_columns must not be empty.')

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  # Validate the type before reading `_num_buckets`, so a non-categorical
  # column gets the documented ValueError rather than an AttributeError.
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  # Weighted columns are compared by the categorical column they wrap.
  if isinstance(c0,
                (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    if isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column of the same type. '
          'Given column: {} of type: {} does not match given column: {} of '
          'type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does  '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable))

  return result
1093
1094
@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embeddings(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embeddings` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embeddings(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `categorical_columns` is empty.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # The input is traversed twice (sorting, then building the result). Copy it
  # once so a one-shot iterator is not exhausted by `sorted` and silently
  # produce an empty result.
  categorical_columns = list(categorical_columns)
  if not categorical_columns:
    raise ValueError('categorical_columns must not be empty.')

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  # Validate the type before reading `num_buckets`, so a non-categorical
  # column gets the documented ValueError rather than an AttributeError.
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  num_buckets = c0.num_buckets
  # Weighted columns are compared by the categorical column they wrap.
  if isinstance(c0, WeightedCategoricalColumn):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    if isinstance(c, WeightedCategoricalColumn):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column of the same type. '
          'Given column: {} of type: {} does not match given column: {} of '
          'type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does  '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  # A single creator owns the shared weights; every column it produces reuses
  # them.
  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result
1259
1260
@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real-valued or numerical features.

  Example:

  ```python
  price = numeric_column('price')
  columns = [price, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  # or
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It serves as the column
      name and as the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers giving the shape of the `Tensor`. A single
      integer may be given instead, meaning a one-dimensional `Tensor` of that
      width. The `Tensor` representing the column has shape
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype`, or an iterable of
      such values, used during `tf.Example` parsing when data is missing. With
      the default of `None`, `tf.parse_example` fails on any example lacking
      this column. A single value is applied to every item; an iterable must
      have a shape equal to `shape`.
    dtype: The type of the values. Defaults to `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function applied to the tensor after
      `default_value` has been used for parsing. It takes the input `Tensor`
      and returns the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2).
      Although normalization is the typical use, any TensorFlow transformation
      is allowed.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int
    ValueError: if any dimension in shape is not a positive integer
    TypeError: if `default_value` is an iterable but not compatible with `shape`
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)

  # Only real integer or floating point types can be cast to float.
  if not dtype.is_integer and not dtype.is_floating:
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))

  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if not (normalizer_fn is None or callable(normalizer_fn)):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)

  column = NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)
  return column
1336
1337
@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input.

  Each bucket includes its left boundary and excludes its right boundary, so
  `boundaries=[0., 1., 2.]` yields the buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000]
                  [150,   10]
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3]
            [3, 2]
            [1, 3]]
  ```

  Example:

  ```python
  price = numeric_column('price')
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)

  # or
  columns = [bucketized_price, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  through `crossed_column`:

  ```python
  price = numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50K)
  columns = [price_x_keywords, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Args:
    source_column: A one-dimensional dense column generated with
      `numeric_column`.
    boundaries: A list or tuple of floats in strictly increasing order giving
      the bucket boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is empty, or is not a sorted list or tuple.
  """
  numeric_types = (NumericColumn, fc_old._NumericColumn)  # pylint: disable=protected-access
  if not isinstance(source_column, numeric_types):
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')

  # Short-circuit keeps the pairwise scan from running on non-sequences.
  is_sequence = isinstance(boundaries, (list, tuple))
  if not is_sequence or any(
      lower >= upper for lower, upper in zip(boundaries, boundaries[1:])):
    raise ValueError('boundaries must be a sorted list.')

  return BucketizedColumn(source_column, tuple(boundaries))
1421
1422
@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing.
  output_id = Hash(input_feature_string) % bucket_size for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  keywords = categorical_column_with_hash_bucket("keywords", 10K)
  columns = [keywords, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)

  # or
  keywords_embedded = embedding_column(keywords, 16)
  columns = [keywords_embedded, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int >= 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is `None` or less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key))

  # A single bucket is allowed: every value then maps to id 0.
  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)
1480
1481
@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
def categorical_column_with_vocabulary_file(key,
                                            vocabulary_file,
                                            vocabulary_size=None,
                                            num_oov_buckets=0,
                                            default_value=None,
                                            dtype=dtypes.string):
  """A `CategoricalColumn` backed by a vocabulary file (TF1 signature).

  Pick this column when the inputs are strings or integers and a file lists
  the vocabulary, so that each value maps to the integer ID given by its line
  number. Out-of-vocabulary values are ignored unless you set exactly one of
  `num_oov_buckets` or `default_value`, which decide how such values are
  mapped.

  This TF1 entry point differs from `categorical_column_with_vocabulary_file_v2`
  only in the positional order of its parameters (`num_oov_buckets` comes
  before `default_value` and `dtype`). All validation and construction are
  delegated to the v2 function.

  The input dictionary `features` holds either a `Tensor` or a `SparseTensor`
  under `features[key]`. For a dense `Tensor`, missing entries may be encoded
  as `-1` (integers) or `''` (strings); this column drops them.

  Example using `num_oov_buckets`:
  '/us/states.txt' has 50 lines, each a 2-character U.S. state abbreviation.
  Inputs found in the file get IDs 0-49 from their line number; every other
  input is hashed into one of IDs 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example using `default_value`:
  '/us/states.txt' has 51 lines: the first is 'XX' and the other 50 are
  2-character U.S. state abbreviations. A literal 'XX' input and any value
  absent from the file both get ID 0; every other input gets its line number,
  1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Either column can then be embedded:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string naming the input feature. It serves as the column
      name and as the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: Name of the vocabulary file.
    vocabulary_size: How many vocabulary entries to use. It must not exceed
      the number of lines in `vocabulary_file`; when it is smaller, trailing
      entries are ignored. When `None`, it is inferred from the file's length.
    num_oov_buckets: Non-negative integer count of out-of-vocabulary buckets.
      Out-of-vocabulary inputs are hashed into IDs in
      `[vocabulary_size, vocabulary_size+num_oov_buckets)`. A positive value
      cannot be combined with `default_value`.
    default_value: Integer ID returned for out-of-vocabulary values; `-1`
      when unset. Cannot be combined with a positive `num_oov_buckets`.
    dtype: Type of the features; only string and integer types are allowed.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is empty, or `vocabulary_size` is None and
      `vocabulary_file` does not exist.
    ValueError: `vocabulary_size` is < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  # The v1 and v2 signatures order their trailing parameters differently, so
  # forward every argument by keyword instead of relying on position.
  return categorical_column_with_vocabulary_file_v2(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)
1571
1572
@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
def categorical_column_with_vocabulary_file_v2(key,
                                               vocabulary_file,
                                               vocabulary_size=None,
                                               dtype=dtypes.string,
                                               default_value=None,
                                               num_oov_buckets=0):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to its line number. All other values are hashed and assigned an
  ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
  in input, and other values missing from the file, will be assigned ID 0. All
  others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the number of lines in
      `vocabulary_file`. An explicitly given size is not checked against the
      file here.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is empty/None, or `vocabulary_size` is None
      and `vocabulary_file` does not exist.
    ValueError: `vocabulary_size` is < 1 (including an empty vocabulary file
      when the size is inferred).
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  if not vocabulary_file:
    raise ValueError('Missing vocabulary_file in {}.'.format(key))

  if vocabulary_size is None:
    if not gfile.Exists(vocabulary_file):
      raise ValueError('vocabulary_file in {} does not exist.'.format(key))

    # Only the line count is needed, so read raw bytes: text mode would try to
    # decode the file and can fail on vocabularies that are not valid UTF-8.
    with gfile.GFile(vocabulary_file, mode='rb') as f:
      vocabulary_size = sum(1 for _ in f)
    logging.info(
        'vocabulary_size = %d in %s is inferred from the number of elements '
        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)

  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
  if vocabulary_size < 1:
    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
  if num_oov_buckets:
    # Validate the value itself first, so a negative count is reported as an
    # invalid value rather than as a conflict with `default_value`.
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
    if default_value is not None:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)
  return VocabularyFileCategoricalColumn(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
      default_value=-1 if default_value is None else default_value,
      dtype=dtype)
1691
1692
@tf_export('feature_column.categorical_column_with_vocabulary_list')
def categorical_column_with_vocabulary_list(key,
                                            vocabulary_list,
                                            dtype=None,
                                            default_value=-1,
                                            num_oov_buckets=0):
  """A `CategoricalColumn` with in-memory vocabulary.

  Use this when your inputs are in string or integer format, and you have an
  in-memory vocabulary mapping each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  In the following example, each input in `vocabulary_list` is assigned an ID
  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
  inputs are hashed and assigned an ID 4-5.

  ```python
  colors = categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
      num_oov_buckets=2)
  columns = [colors, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  In the following example, each input in `vocabulary_list` is assigned an ID
  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
  inputs are assigned `default_value` 0.

  ```python
  colors = categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
  columns = [colors, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(colors, 3),...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the column
      name and the dictionary key for feature parsing configs, feature `Tensor`
      objects, and feature columns.
    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
      is mapped to the index of its value (if present) in `vocabulary_list`.
      Must be castable to `dtype`.
    dtype: The type of features. Only string and integer types are supported. If
      `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` can not be specified
      with `default_value`.

  Returns:
    A `CategoricalColumn` with in-memory vocabulary.

  Raises:
    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: if `dtype` is not integer or string.
    ValueError: if `dtype` and the dtype inferred from `vocabulary_list` do not
      agree on being integer vs. non-integer.
  """
  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
    raise ValueError(
        'vocabulary_list {} must be non-empty, column_name: {}'.format(
            vocabulary_list, key))
  if len(set(vocabulary_list)) != len(vocabulary_list):
    raise ValueError(
        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
            vocabulary_list, key))
  # Infer the vocabulary's dtype through numpy (e.g. str -> tf.string).
  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
  if num_oov_buckets:
    # Validate the value itself first, so a negative count is reported as an
    # invalid value rather than as a conflict with `default_value`.
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
    if default_value != -1:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
  fc_utils.assert_string_or_int(
      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
  if dtype is None:
    dtype = vocabulary_dtype
  elif dtype.is_integer != vocabulary_dtype.is_integer:
    raise ValueError(
        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
            dtype, vocabulary_dtype, key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)

  return VocabularyListCategoricalColumn(
      key=key,
      vocabulary_list=tuple(vocabulary_list),
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)
1808
1809
@tf_export('feature_column.categorical_column_with_identity')
def categorical_column_with_identity(key, num_buckets, default_value=None):
  """A `CategoricalColumn` that returns identity values.

  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values outside
  this range will result in `default_value` if specified, otherwise it will
  fail.

  Typically, this is used for contiguous ranges of integer indexes, but
  it doesn't have to be. This might be inefficient, however, if many of the IDs
  are unused. Consider `categorical_column_with_hash_bucket` in that case.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  In the following examples, each input in the range `[0, 1000000)` is used as
  its own ID. All other inputs are assigned `default_value` 0. Note that a
  literal 0 in inputs will result in the same default ID.

  Linear model:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [video_id, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Embedding for a DNN model:

  ```python
  columns = [embedding_column(video_id, 9),...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    num_buckets: Range of inputs and outputs is `[0, num_buckets)`. Must be at
      least 1.
    default_value: If `None`, this column's graph operations will fail for
      out-of-range inputs. Otherwise, this value must be in the range
      `[0, num_buckets)`, and will replace inputs outside of that range.

  Returns:
    A `CategoricalColumn` that returns identity values.

  Raises:
    ValueError: if `num_buckets` is `None` or less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  # Check `None` explicitly so a missing size raises ValueError rather than
  # the TypeError that `None < 1` produces under Python 3.
  if num_buckets is None or num_buckets < 1:
    raise ValueError(
        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
  if (default_value is not None) and (
      (default_value < 0) or (default_value >= num_buckets)):
    raise ValueError(
        'default_value {} not in range [0, {}), column_name {}'.format(
            default_value, num_buckets, key))
  fc_utils.assert_key_is_string(key)
  return IdentityCategoricalColumn(
      key=key, number_buckets=num_buckets, default_value=default_value)
1876
1877
@tf_export('feature_column.indicator_column')
def indicator_column(categorical_column):
  """Represents multi-hot representation of given categorical column.

  - For DNN model, `indicator_column` can be used to wrap any
    `categorical_column_*` (e.g., to feed to DNN). Consider using
    `embedding_column` if the number of buckets/unique(values) is large.

  - For Wide (aka linear) model, `indicator_column` is the internal
    representation for categorical column when passing categorical column
    directly (as any element in feature_columns) to `linear_model`. See
    `linear_model` for details.

  ```python
  name = indicator_column(categorical_column_with_vocabulary_list(
      'name', ['bob', 'george', 'wanda']))
  columns = [name, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
  ```

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` or `crossed_column` functions.

  Returns:
    An `IndicatorColumn`.
  """
  return IndicatorColumn(categorical_column)
1911
1912
@tf_export('feature_column.weighted_categorical_column')
def weighted_categorical_column(categorical_column,
                                weight_feature_key,
                                dtype=dtypes.float32):
  """Applies weight values to a `CategoricalColumn`.

  Use this when each of your sparse inputs has both an ID and a value. For
  example, if you're representing text documents as a collection of word
  frequencies, you can provide 2 parallel sparse input features ('terms' and
  'frequencies' below).

  Example:

  Input `tf.Example` objects:

  ```proto
  [
    features {
      feature {
        key: "terms"
        value {bytes_list {value: "very" value: "model"}}
      }
      feature {
        key: "frequencies"
        value {float_list {value: 0.3 value: 0.1}}
      }
    },
    features {
      feature {
        key: "terms"
        value {bytes_list {value: "when" value: "course" value: "human"}}
      }
      feature {
        key: "frequencies"
        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
      }
    }
  ]
  ```

  ```python
  categorical_column = categorical_column_with_hash_bucket(
      key='terms', hash_bucket_size=1000)
  weighted_column = weighted_categorical_column(
      categorical_column=categorical_column, weight_feature_key='frequencies')
  columns = [weighted_column, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  This assumes the input dictionary contains a `SparseTensor` for key
  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
  the same indices and dense shape.

  Args:
    categorical_column: A `CategoricalColumn` created by
      `categorical_column_with_*` functions.
    weight_feature_key: String key for weight values.
    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
      are supported.

  Returns:
    A `CategoricalColumn` composed of two sparse features: one represents id,
    the other represents weight (value) of the id feature in that example.

  Raises:
    ValueError: if `dtype` is `None` or is neither an integer nor a floating
      point type.
  """
  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
  return WeightedCategoricalColumn(
      categorical_column=categorical_column,
      weight_feature_key=weight_feature_key,
      dtype=dtype)
1987
1988
@tf_export('feature_column.crossed_column')
def crossed_column(keys, hash_bucket_size, hash_key=None):
  """Returns a column for performing crosses of categorical features.

  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
  the transformation can be thought of as:
    Hash(cartesian product of features) % `hash_bucket_size`

  For example, if the input features are:

  * SparseTensor referred by first key:

    ```python
    shape = [2, 2]
    {
        [0, 0]: "a"
        [1, 0]: "b"
        [1, 1]: "c"
    }
    ```

  * SparseTensor referred by second key:

    ```python
    shape = [2, 1]
    {
        [0, 0]: "d"
        [1, 0]: "e"
    }
    ```

  then crossed feature will look like:

  ```python
  shape = [2, 2]
  {
      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
  }
  ```

  Here is an example to create a linear model with crosses of string features:

  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
  columns = [keywords_x_doc_terms, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  You could also use vocabulary lookup before crossing:

  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
  columns = [keywords_x_doc_terms, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  If an input feature is of numeric type, you can use
  `categorical_column_with_identity`, or `bucketized_column`, as in the example:

  ```python
  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10K)
  price = numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
  columns = [vertical_id_x_price, ...]
  features = tf.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  To use crossed column in DNN model, you need to add it in an embedding column
  as in this example:

  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
  ```

  Args:
    keys: An iterable identifying the features to be crossed. Each element can
      be either:
      * string: Will use the corresponding feature which must be of string type.
      * `CategoricalColumn`: Will use the transformed tensor produced by this
        column. Does not support hashed categorical column.
    hash_bucket_size: A positive int (>= 1). The number of buckets.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseCrossOp (optional).

  Returns:
    A `CrossedColumn`.

  Raises:
    ValueError: If `keys` is None or has fewer than 2 elements.
    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
    ValueError: If any of the keys is `HashedCategoricalColumn`.
    ValueError: If `hash_bucket_size` is None or < 1.
  """
  if not hash_bucket_size or hash_bucket_size < 1:
    # NOTE(review): 1 passes this check even though the message says '> 1';
    # the message text is left unchanged since callers/tests may match on it.
    raise ValueError('hash_bucket_size must be > 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
  # Materialize `keys` once so any iterable (including generators, which have
  # no len()) is accepted and is not exhausted by the validation loop below.
  keys_tuple = () if keys is None else tuple(keys)
  if len(keys_tuple) < 2:
    raise ValueError(
        'keys must be a list with length > 1. Given: {}'.format(keys))
  for key in keys_tuple:
    if (not isinstance(key, six.string_types) and
        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
      raise ValueError(
          'Unsupported key type. All keys must be either string, or '
          'categorical column except HashedCategoricalColumn. '
          'Given: {}'.format(key))
    if isinstance(key,
                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'categorical_column_with_hash_bucket is not supported for crossing. '
          'Hashing before crossing will increase probability of collision. '
          'Instead, use the feature name as a string. Given: {}'.format(key))
  return CrossedColumn(
      keys=keys_tuple, hash_bucket_size=hash_bucket_size, hash_key=hash_key)
2115
2116
2117@six.add_metaclass(abc.ABCMeta)
2118class FeatureColumn(object):
2119  """Represents a feature column abstraction.
2120
2121  WARNING: Do not subclass this layer unless you know what you are doing:
2122  the API is subject to future changes.
2123
2124  To distinguish between the concept of a feature family and a specific binary
2125  feature within a family, we refer to a feature family like "country" as a
2126  feature column. For example, we can have a feature in a `tf.Example` format:
2127    {key: "country",  value: [ "US" ]}
2128  In this example the value of feature is "US" and "country" refers to the
2129  column of the feature.
2130
2131  This class is an abstract class. Users should not create instances of this.
2132  """
2133
2134  @abc.abstractproperty
2135  def name(self):
2136    """Returns string. Used for naming."""
2137    pass
2138
2139  @abc.abstractmethod
2140  def transform_feature(self, transformation_cache, state_manager):
2141    """Returns intermediate representation (usually a `Tensor`).
2142
2143    Uses `transformation_cache` to create an intermediate representation
2144    (usually a `Tensor`) that other feature columns can use.
2145
2146    Example usage of `transformation_cache`:
2147    Let's say a Feature column depends on raw feature ('raw') and another
2148    `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
2149    transformation_cache will be used as follows:
2150
2151    ```python
2152    raw_tensor = transformation_cache.get('raw', state_manager)
2153    fc_tensor = transformation_cache.get(input_fc, state_manager)
2154    ```
2155
2156    Args:
2157      transformation_cache: A `FeatureTransformationCache` object to access
2158        features.
2159      state_manager: A `StateManager` to create / access resources such as
2160        lookup tables.
2161
2162    Returns:
2163      Transformed feature `Tensor`.
2164    """
2165    pass
2166
2167  @abc.abstractproperty
2168  def parse_example_spec(self):
2169    """Returns a `tf.Example` parsing spec as dict.
2170
2171    It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a
2172    dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other
2173    supported objects. Please check documentation of `tf.parse_example` for all
2174    supported spec objects.
2175
2176    Let's say a Feature column depends on raw feature ('raw') and another
2177    `FeatureColumn` (input_fc). One possible implementation of
2178    parse_example_spec is as follows:
2179
2180    ```python
2181    spec = {'raw': tf.FixedLenFeature(...)}
2182    spec.update(input_fc.parse_example_spec)
2183    return spec
2184    ```
2185    """
2186    pass
2187
2188  def create_state(self, state_manager):
2189    """Uses the `state_manager` to create state for the FeatureColumn.
2190
2191    Args:
2192      state_manager: A `StateManager` to create / access resources such as
2193        lookup tables and variables.
2194    """
2195    pass
2196
2197  @abc.abstractproperty
2198  def _is_v2_column(self):
2199    """Returns whether this FeatureColumn is fully conformant to the new API.
2200
2201    This is needed for composition type cases where an EmbeddingColumn etc.
2202    might take in old categorical columns as input and then we want to use the
2203    old API.
2204    """
2205    pass
2206
2207  @abc.abstractproperty
2208  def parents(self):
2209    """Returns a list of immediate raw feature and FeatureColumn dependencies.
2210
2211    For example:
2212    # For the following feature columns
2213    a = numeric_column('f1')
2214    c = crossed_column(a, 'f2')
2215    # The expected parents are:
2216    a.parents = ['f1']
2217    c.parents = [a, 'f2']
2218    """
2219    pass
2220
2221  @abc.abstractmethod
2222  def _get_config(self):
2223    """Returns the config of the feature column.
2224
2225    A FeatureColumn config is a Python dictionary (serializable) containing the
2226    configuration of a FeatureColumn. The same FeatureColumn can be
2227    reinstantiated later from this configuration.
2228
2229    The config of a feature column does not include information about feature
2230    columns depending on it nor the FeatureColumn class name.
2231
2232    Example with (de)serialization practices followed in this file:
2233    ```python
2234    class SerializationExampleFeatureColumn(
2235        FeatureColumn, collections.namedtuple(
2236            'SerializationExampleFeatureColumn',
2237            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):
2238
2239      def _get_config(self):
2240        # Create a dict from the namedtuple.
2241        # Python attribute literals can be directly copied from / to the config.
2242        # For example 'dimension', assuming it is an integer literal.
2243        config = dict(zip(self._fields, self))
2244
2245        # (De)serialization of parent FeatureColumns should use the provided
2246        # (de)serialize_feature_column() methods that take care of de-duping.
2247        config['parent'] = serialize_feature_column(self.parent)
2248
2249        # Many objects provide custom (de)serialization e.g: for tf.DType
2250        # tf.DType.name, tf.as_dtype() can be used.
2251        config['dtype'] = self.dtype.name
2252
2253        # Non-trivial dependencies should be Keras-(de)serializable.
2254        config['normalizer_fn'] = utils.serialize_keras_object(
2255            self.normalizer_fn)
2256
2257        return config
2258
2259      @classmethod
2260      def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2261        # This should do the inverse transform from `_get_config` and construct
2262        # the namedtuple.
2263        kwargs = config.copy()
2264        kwargs['parent'] = deserialize_feature_column(
2265            config['parent'], custom_objects, columns_by_name)
2266        kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2267        kwargs['normalizer_fn'] = utils.deserialize_keras_object(
2268          config['normalizer_fn'], custom_objects=custom_objects)
2269        return cls(**kwargs)
2270
2271    ```
2272    Returns:
2273      A serializable Dict that can be used to deserialize the object with
2274      from_config.
2275    """
2276    pass
2277
2278  @classmethod
2279  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2280    """Creates a FeatureColumn from its config.
2281
2282    This method should be the reverse of `_get_config`, capable of instantiating
2283    the same FeatureColumn from the config dictionary. See `_get_config` for an
2284    example of common (de)serialization practices followed in this file.
2285
2286    TODO(b/118939620): This is a private method until consensus is reached on
2287    supporting object deserialization deduping within Keras.
2288
2289    Args:
2290      config: A Dict config acquired with `_get_config`.
2291      custom_objects: Optional dictionary mapping names (strings) to custom
2292        classes or functions to be considered during deserialization.
2293      columns_by_name: A Dict[String, FeatureColumn] of existing columns in
2294        order to avoid duplication. Should be passed to any calls to
2295        deserialize_feature_column().
2296
2297    Returns:
2298      A FeatureColumn for the input config.
2299    """
2300    pass
2301
2302
class DenseColumn(FeatureColumn):
  """A `FeatureColumn` whose output can be represented as a dense `Tensor`.

  Examples include `numeric_column`, `embedding_column` and
  `indicator_column`.
  """

  @abc.abstractproperty
  def variable_shape(self):
    """`TensorShape` of the `get_dense_tensor` output, minus the batch dim."""
    pass

  @abc.abstractmethod
  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the dense `Tensor` representation of this column.

    Model-building functions consume this output. For instance, `input_layer`
    can be sketched as:

    ```python
    def input_layer(features, feature_columns, ...):
      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
      return tf.concat(outputs)
    ```

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      `Tensor` of shape [batch_size] + `variable_shape`.
    """
    pass
2338
2339
def is_feature_column_v2(feature_columns):
  """Returns True if all feature columns are V2.

  A column qualifies when it is a `FeatureColumn` instance whose
  `_is_v2_column` property is truthy. Evaluation stops at the first column
  that fails; an empty iterable yields True.
  """
  return all(
      isinstance(column, FeatureColumn) and
      column._is_v2_column  # pylint: disable=protected-access
      for column in feature_columns)
2348
2349
def _create_weighted_sum(column, transformation_cache, state_manager,
                         sparse_combiner, weight_var):
  """Creates a weighted sum for a dense/categorical column for linear_model.

  `CategoricalColumn`s use the sparse (embedding-lookup based) path. Every
  other column is treated as dense, and `sparse_combiner` is ignored for it.
  """
  if not isinstance(column, CategoricalColumn):
    return _create_dense_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        weight_var=weight_var)
  return _create_categorical_column_weighted_sum(
      column=column,
      transformation_cache=transformation_cache,
      state_manager=state_manager,
      sparse_combiner=sparse_combiner,
      weight_var=weight_var)
2366
2367
def _create_dense_column_weighted_sum(column, transformation_cache,
                                      state_manager, weight_var):
  """Create a weighted sum of a dense column for linear_model.

  The column's dense tensor is flattened to
  `[batch_size, variable_shape.num_elements()]` and then multiplied by
  `weight_var`.
  """
  dense = column.get_dense_tensor(transformation_cache, state_manager)
  # Collapse all non-batch dimensions so a single matmul applies the weights.
  flat_shape = (array_ops.shape(dense)[0],
                column.variable_shape.num_elements())
  flat = array_ops.reshape(dense, shape=flat_shape)
  return math_ops.matmul(flat, weight_var, name='weighted_sum')
2376
2377
class CategoricalColumn(FeatureColumn):
  """Represents a categorical feature.

  A categorical feature is typically handled with a `tf.SparseTensor` of IDs.
  """

  # Return type of `get_sparse_tensors`: ids plus optional per-id weights.
  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
      'IdWeightPair', ('id_tensor', 'weight_tensor'))

  @abc.abstractproperty
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    pass

  @abc.abstractmethod
  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      A `CategoricalColumn.IdWeightPair`.
    """
    pass
2413
2414
def _create_categorical_column_weighted_sum(
    column, transformation_cache, state_manager, sparse_combiner, weight_var):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Create a weighted sum of a categorical column for linear_model.

  Note to maintainer: as an implementation detail, the weighted sum is
  computed with `safe_embedding_lookup_sparse` for efficiency. Mathematically,
  the two are the same.

  Conceptually, a categorical column can be treated as a multi-hot vector.
  For example:

  ```python
    x = [0 0 1]  # categorical column input
    w = [a b c]  # weights
  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.

  Another example is

  ```python
    x = [0 1 1]  # categorical column input
    w = [a b c]  # weights
  ```
  The weighted sum is `b + c` in this case, which is the same as
  `w[1] + w[2]`.

  In both cases, the weighted sum is an embedding lookup of the active ids in
  `w` with sparse_combiner = "sum".
  """

  sparse_tensors = column.get_sparse_tensors(transformation_cache,
                                             state_manager)
  # Flatten every non-batch dimension so each example's ids form one row.
  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
      array_ops.shape(sparse_tensors.id_tensor)[0], -1
  ])
  weight_tensor = sparse_tensors.weight_tensor
  if weight_tensor is not None:
    # Weights must be reshaped identically so they stay aligned with the ids.
    weight_tensor = sparse_ops.sparse_reshape(
        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])

  return embedding_ops.safe_embedding_lookup_sparse(
      weight_var,
      id_tensor,
      sparse_weights=weight_tensor,
      combiner=sparse_combiner,
      name='weighted_sum')
2461
2462
class SequenceDenseColumn(FeatureColumn):
  """A `FeatureColumn` producing dense sequence data."""

  # Return type of `get_sequence_dense_tensor`.
  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))

  @abc.abstractmethod
  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `TensorSequenceLengthPair` for this column.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    pass
2480
2481
class FeatureTransformationCache(object):
  """Handles caching of transformations while building the model.

  `FeatureColumn` specifies how to digest an input column to the network. Some
  feature columns require data transformations. This class caches those
  transformations.

  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.

  Example:
  We're trying to use the following `FeatureColumn`s:

  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
  ```

  If we transform each column independently, then we'll get duplication of
  bucketization (one for cross, one for bucketization itself).
  The `FeatureTransformationCache` eliminates this duplication.
  """

  def __init__(self, features):
    """Creates a `FeatureTransformationCache`.

    Args:
      features: A mapping from `str` or `FeatureColumn` keys to objects that
        are `Tensor` or `SparseTensor`, or can be converted to same via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
        signifies a base feature (not-transformed). A `FeatureColumn` key
        means that this `Tensor` is the output of an existing `FeatureColumn`
        which can be reused.
    """
    # Shallow copy: later changes to the caller's dict are not seen here.
    self._features = features.copy()
    # Memoized results, keyed by the same `str` / `FeatureColumn` keys.
    self._feature_tensors = {}

  def get(self, key, state_manager):
    """Returns a `Tensor` for the given key.

    A `str` key is used to access a base feature (not-transformed). When a
    `FeatureColumn` is passed, the transformed feature is returned if it
    already exists, otherwise the given `FeatureColumn` is asked to provide its
    transformed output, which is then cached.

    Args:
      key: a `str` or a `FeatureColumn`.
      state_manager: A StateManager object that holds the FeatureColumn state.

    Returns:
      The transformed `Tensor` corresponding to the `key`.

    Raises:
      ValueError: if key is not found or a transformed `Tensor` cannot be
        computed.
      TypeError: if key is neither a `str` nor a `FeatureColumn`.
    """
    if key in self._feature_tensors:
      # FeatureColumn is already transformed or converted.
      return self._feature_tensors[key]

    if key in self._features:
      feature_tensor = self._get_raw_feature_as_tensor(key)
      self._feature_tensors[key] = feature_tensor
      return feature_tensor

    if isinstance(key, six.string_types):
      raise ValueError('Feature {} is not in features dictionary.'.format(key))

    if not isinstance(key, FeatureColumn):
      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
                      'Provided: {}'.format(key))

    column = key
    logging.debug('Transforming feature_column %s.', column)
    # The column may recursively call `get` for its own inputs, which is how
    # shared sub-transformations end up being computed only once.
    transformed = column.transform_feature(self, state_manager)
    if transformed is None:
      raise ValueError('Column {} is not supported.'.format(column.name))
    self._feature_tensors[column] = transformed
    return transformed

  def _get_raw_feature_as_tensor(self, key):
    """Gets the raw_feature (keyed by `key`) as `tensor`.

    The raw feature is converted to (sparse) tensor and maybe expand dim.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
    error out as it is not supported.

    Args:
      key: A `str` key to access the raw feature.

    Returns:
      A `Tensor` or `SparseTensor`.

    Raises:
      ValueError: if the raw feature has rank 0.
    """
    raw_feature = self._features[key]
    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        raw_feature)

    def expand_dims(input_tensor):
      # Input_tensor must have rank 1.
      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
        return sparse_ops.sparse_reshape(
            input_tensor, [array_ops.shape(input_tensor)[0], 1])
      else:
        return array_ops.expand_dims(input_tensor, -1)

    rank = feature_tensor.get_shape().ndims
    if rank is not None:
      if rank == 0:
        raise ValueError(
            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))
      return feature_tensor if rank != 1 else expand_dims(feature_tensor)

    # Handle dynamic rank: assert rank > 0 at run time, then expand only when
    # the runtime rank is 1.
    with ops.control_dependencies([
        check_ops.assert_positive(
            array_ops.rank(feature_tensor),
            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))]):
      return control_flow_ops.cond(
          math_ops.equal(1, array_ops.rank(feature_tensor)),
          lambda: expand_dims(feature_tensor),
          lambda: feature_tensor)
2616
2617
2618# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.

  If `input_tensor` is already a `SparseTensor`, just return it.

  Args:
    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, a dtype-dependent
      default is used: '' for strings, -1 for integers, and the numpy default
      value of the dtype (e.g. 0.0) otherwise.

  Returns:
    A `SparseTensor` with the same shape as `input_tensor`.
  """
  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
      input_tensor)
  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
    return input_tensor
  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
    if ignore_value is None:
      if input_tensor.dtype == dtypes.string:
        # Exception due to TF strings are converted to numpy objects by default.
        ignore_value = ''
      elif input_tensor.dtype.is_integer:
        ignore_value = -1  # -1 has a special meaning of missing feature
      else:
        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
        # constructing a new numpy object of the given type, which yields the
        # default value for that type.
        ignore_value = input_tensor.dtype.as_numpy_dtype()
    ignore_value = math_ops.cast(
        ignore_value, input_tensor.dtype, name='ignore_value')
    # Coordinates of every cell that is kept.
    indices = array_ops.where(
        math_ops.not_equal(input_tensor, ignore_value), name='indices')
    return sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=array_ops.gather_nd(input_tensor, indices, name='values'),
        dense_shape=array_ops.shape(
            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2661
2662
def _normalize_feature_columns(feature_columns):
  """Normalizes the `feature_columns` input.

  This method converts the `feature_columns` to list type as best as it can. In
  addition, verifies the type and other parts of feature_columns, required by
  downstream library.

  Args:
    feature_columns: The raw feature columns, usually passed by users. May be a
      single `FeatureColumn`, an iterator, or an iterable (but not a dict).

  Returns:
    The normalized feature column list, sorted by column name.

  Raises:
    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
  """
  # `collections.Iterator` was removed in Python 3.10; the ABC lives in
  # `collections.abc` on Python 3. Fall back to `collections` on Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  if isinstance(feature_columns, FeatureColumn):
    feature_columns = [feature_columns]

  # Materialize one-shot iterators (e.g. generators) so they can be walked
  # more than once below.
  if isinstance(feature_columns, collections_abc.Iterator):
    feature_columns = list(feature_columns)

  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('Items of feature_columns must be a FeatureColumn. '
                       'Given (type {}): {}.'.format(type(column), column))
  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')
  name_to_column = dict()
  for column in feature_columns:
    if column.name in name_to_column:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer to '
                       'same base feature. Either one must be discarded or a '
                       'duplicated but renamed item must be inserted in '
                       'features dict.'.format(column,
                                               name_to_column[column.name]))
    name_to_column[column.name] = column

  return sorted(feature_columns, key=lambda x: x.name)
2706
2707
class NumericColumn(
    DenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'NumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """See `numeric_column`."""

  @property
  def _is_v2_column(self):
    # Depends only on a raw feature key, never on an old-API column.
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {
        self.key:
            parsing_ops.FixedLenFeature(self.shape, self.dtype,
                                        self.default_value)
    }

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Applies `normalizer_fn` (if any) and casts the result to float32.

    Shared by the v1 (`_transform_feature`) and v2 (`transform_feature`)
    code paths.

    Raises:
      ValueError: If `input_tensor` is a `SparseTensor`.
    """
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'The corresponding Tensor of numerical column must be a Tensor. '
          'SparseTensor is not supported. key: {}'.format(self.key))
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return math_ops.cast(input_tensor, dtypes.float32)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    return self._transform_input_tensor(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.

    Raises:
      ValueError: If a SparseTensor is passed in.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(self.shape)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing numeric feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.
    """
    # The cache returns the previously transformed tensor, or calls
    # `transform_feature` to create (and cache) it on first access.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    return inputs.get(self)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    # Non-literal fields need explicit (de)serialization.
    config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['normalizer_fn'] = utils.deserialize_keras_object(
        config['normalizer_fn'], custom_objects=custom_objects)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
2829
2830
class BucketizedColumn(
    DenseColumn,
    CategoricalColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('BucketizedColumn',
                           ('source_column', 'boundaries'))):
  """See `bucketized_column`."""

  @property
  def _is_v2_column(self):
    # v2 only if the wrapped source column is itself a v2 FeatureColumn.
    return (isinstance(self.source_column, FeatureColumn) and
            self.source_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_bucketized'.format(self.source_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.source_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.source_column._parse_example_spec  # pylint: disable=protected-access

  def _transform_input_tensor(self, source_tensor):
    """Bucketizes `source_tensor` using `self.boundaries`.

    Shared by the v1 (`_transform_feature`) and v2 (`transform_feature`)
    code paths.
    """
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = inputs.get(self.source_column)
    return self._transform_input_tensor(source_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = transformation_cache.get(self.source_column, state_manager)
    return self._transform_input_tensor(source_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    # One one-hot slot per bucket for every element of the source shape.
    return tensor_shape.TensorShape(
        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def _get_dense_tensor_for_input_tensor(self, input_tensor):
    """One-hot encodes bucket ids into `len(boundaries) + 1` float slots."""
    return array_ops.one_hot(
        indices=math_ops.cast(input_tensor, dtypes.int64),
        depth=len(self.boundaries) + 1,
        on_value=1.,
        off_value=0.)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns one hot encoded dense `Tensor`."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    # By construction, source_column is always one-dimensional.
    return (len(self.boundaries) + 1) * self.source_column.shape[0]

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
    """Builds a `[batch_size, source_dimension]` SparseTensor of bucket ids."""
    batch_size = array_ops.shape(input_tensor)[0]
    # By construction, source_column is always one-dimensional.
    source_dimension = self.source_column.shape[0]

    # i1: row (example) index of every entry; i2: column (dimension) index.
    i1 = array_ops.reshape(
        array_ops.tile(
            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
            [1, source_dimension]),
        (-1,))
    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
    # Flatten the bucket indices and unique them across dimensions
    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
    bucket_indices = (
        array_ops.reshape(input_tensor, (-1,)) +
        (len(self.boundaries) + 1) * i2)

    indices = math_ops.cast(
        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
    dense_shape = math_ops.cast(
        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
    sparse_tensor = sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=bucket_indices,
        dense_shape=dense_shape)
    return CategoricalColumn.IdWeightPair(sparse_tensor, None)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.source_column]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['source_column'] = serialize_feature_column(self.source_column)
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['source_column'] = deserialize_feature_column(
        config['source_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
2982
2983
class EmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'EmbeddingColumn',
        ('categorical_column', 'dimension', 'combiner', 'initializer',
         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
  """See `embedding_column`.

  Dense column that maps the ids produced by `categorical_column` to rows of a
  float32 table of shape `(categorical_column._num_buckets, dimension)` and
  combines them with `combiner` via `safe_embedding_lookup_sparse`.

  The v2 API (`create_state`, `get_dense_tensor`, `get_sequence_dense_tensor`)
  obtains the table from a `StateManager`. The deprecated v1 API
  (`_get_dense_tensor`, `_get_sequence_dense_tensor`) creates it with
  `variable_scope.get_variable`.
  """

  @property
  def _is_v2_column(self):
    # Only a v2 column if the wrapped categorical column is one as well.
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """Transforms underlying `categorical_column`."""
    return transformation_cache.get(self.categorical_column, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return inputs.get(self.categorical_column)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.vector(self.dimension)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def create_state(self, state_manager):
    """Creates the embedding lookup variable."""
    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    state_manager.create_variable(
        self,
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        trainable=self.trainable,
        # TODO(rohanj): Make this True when b/118500434 is fixed.
        use_resource=False,
        initializer=self.initializer)

  def _get_dense_tensor_internal_helper(self, sparse_tensors,
                                        embedding_weights):
    """Looks up and combines the embeddings for `sparse_tensors`.

    If `ckpt_to_load_from` is set, `embedding_weights` (or, for a
    `PartitionedVariable`, its list of partitions) is first passed to
    `checkpoint_utils.init_from_checkpoint` under `tensor_name_in_ckpt`.

    Args:
      sparse_tensors: `CategoricalColumn.IdWeightPair` holding the ids and
        optional weights.
      embedding_weights: The embedding table (possibly partitioned).

    Returns:
      The result of `embedding_ops.safe_embedding_lookup_sparse`.
    """
    sparse_ids = sparse_tensors.id_tensor
    sparse_weights = sparse_tensors.weight_tensor

    if self.ckpt_to_load_from is not None:
      to_restore = embedding_weights
      if isinstance(to_restore, variables.PartitionedVariable):
        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
          self.tensor_name_in_ckpt: to_restore
      })

    # Return embedding lookup result.
    return embedding_ops.safe_embedding_lookup_sparse(
        embedding_weights=embedding_weights,
        sparse_ids=sparse_ids,
        sparse_weights=sparse_weights,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)

  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
    """Private method that follows the signature of get_dense_tensor."""
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
                                     trainable):
    """Private method that follows the signature of _get_dense_tensor.

    Creates (or reuses, according to the enclosing variable scope) the
    embedding table with `variable_scope.get_variable` and performs the lookup.

    Args:
      sparse_tensors: `CategoricalColumn.IdWeightPair` holding ids and weights.
      weight_collections: None, or an iterable of collection names for the
        embedding variable. When given, `GraphKeys.GLOBAL_VARIABLES` is always
        included. The caller's object is never modified.
      trainable: The variable is trainable only if both this and
        `self.trainable` are truthy.

    Returns:
      Embedding lookup tensor.
    """
    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if weight_collections is not None:
      # Work on a copy: callers may share one list across many columns or pass
      # an immutable tuple (which has no `append`). An explicitly empty list
      # must still get GLOBAL_VARIABLES, matching the `collections=None`
      # default of `get_variable`; the old truthiness check skipped it.
      weight_collections = list(weight_collections)
      if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
        weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns tensor after doing the embedding lookup.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Embedding lookup tensor.

    Raises:
      ValueError: `categorical_column` is SequenceCategoricalColumn.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_dense_tensor_internal(sparse_tensors, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
        inputs, weight_collections, trainable)
    return self._old_get_dense_tensor_internal(sparse_tensors,
                                               weight_collections, trainable)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
                                                   state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    dense_tensor = self._old_get_dense_tensor_internal(
        sparse_tensors,
        weight_collections=weight_collections,
        trainable=trainable)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['initializer'] = utils.serialize_keras_object(self.initializer)
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['initializer'] = utils.deserialize_keras_object(
        config['initializer'], custom_objects=custom_objects)
    return cls(**kwargs)
3213
3214
3215def _raise_shared_embedding_column_error():
3216  raise ValueError('SharedEmbeddingColumns are not supported in '
3217                   '`linear_model` or `input_layer`. Please use '
3218                   '`DenseFeatures` or `LinearModel` instead.')
3219
3220
class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
  """Owns the embedding table shared by a group of `SharedEmbeddingColumn`s.

  Calling an instance returns a new `SharedEmbeddingColumn` bound to this
  creator. Each such column reads its table from `embedding_weights`, so all
  of them use the same variable. One variable is cached per default graph.
  """

  def __init__(self,
               dimension,
               initializer,
               ckpt_to_load_from,
               tensor_name_in_ckpt,
               num_buckets,
               trainable,
               name='shared_embedding_column_creator'):
    self._dimension = dimension
    self._initializer = initializer
    self._ckpt_to_load_from = ckpt_to_load_from
    self._tensor_name_in_ckpt = tensor_name_in_ckpt
    self._num_buckets = num_buckets
    self._trainable = trainable
    self._name = name
    # Embedding variables already created, keyed by graph key.
    self._embedding_weights = {}

  def __call__(self, categorical_column, combiner, max_norm):
    """Returns a `SharedEmbeddingColumn` that uses this creator's table."""
    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm)

  @property
  def embedding_weights(self):
    """Returns the shared embedding variable for the current default graph.

    On first access within a graph the variable is created with
    `variable_scope.get_variable` (shape `(num_buckets, dimension)`, float32).
    If a checkpoint was configured, it is also handed to
    `checkpoint_utils.init_from_checkpoint`.
    """
    graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if graph_key in self._embedding_weights:
      return self._embedding_weights[graph_key]

    table = variable_scope.get_variable(
        name=self._name,
        shape=(self._num_buckets, self._dimension),
        dtype=dtypes.float32,
        initializer=self._initializer,
        trainable=self._trainable)

    if self._ckpt_to_load_from is not None:
      # Partitioned tables are restored partition by partition.
      restore_target = (
          table._get_variable_list()  # pylint: disable=protected-access
          if isinstance(table, variables.PartitionedVariable) else table)
      checkpoint_utils.init_from_checkpoint(
          self._ckpt_to_load_from, {self._tensor_name_in_ckpt: restore_target})

    self._embedding_weights[graph_key] = table
    return self._embedding_weights[graph_key]

  @property
  def dimension(self):
    """Width of each embedding vector."""
    return self._dimension
3268
3269
class SharedEmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'SharedEmbeddingColumn',
        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
         'max_norm'))):
  """Embedding column whose table is owned by a `SharedEmbeddingColumnCreator`.

  Every column built by the same creator looks up ids in the same variable.
  Only the v2 API is supported. The v1 hooks (`_parse_example_spec`,
  `_transform_feature`, `_variable_shape`, `_get_dense_tensor`,
  `_get_sequence_dense_tensor`) raise a `ValueError`.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_shared_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  def _parse_example_spec(self):
    return _raise_shared_embedding_column_error()

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    Delegates to the transformed value of `categorical_column`.
    """
    return transformation_cache.get(self.categorical_column, state_manager)

  def _transform_feature(self, inputs):
    return _raise_shared_embedding_column_error()

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    creator = self.shared_embedding_column_creator
    return tensor_shape.vector(creator.dimension)

  @property
  def _variable_shape(self):
    return _raise_shared_embedding_column_error()

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Looks up and combines the shared embeddings for this column's ids.

    Follows the signature of `get_dense_tensor`. The shared table is one
    variable for many columns, so the lookup ops are created under a fresh
    name scope derived from `self.name` to keep op names distinct per column.
    """
    with ops.name_scope(None, default_name=self.name):
      id_weight_pair = self.categorical_column.get_sparse_tensors(
          transformation_cache, state_manager)
      table = self.shared_embedding_column_creator.embedding_weights
      return embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights=table,
          sparse_ids=id_weight_pair.id_tensor,
          sparse_weights=id_weight_pair.weight_tensor,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embedding lookup result.

    Raises:
      ValueError: if `categorical_column` is a `SequenceCategoricalColumn`.
    """
    column = self.categorical_column
    if not isinstance(column, SequenceCategoricalColumn):
      return self._get_dense_tensor_internal(transformation_cache,
                                             state_manager)
    raise ValueError(
        'In embedding_column: {}. '
        'categorical_column must not be of type SequenceCategoricalColumn. '
        'Suggested fix A: If you wish to use DenseFeatures, use a '
        'non-sequence categorical_column_with_*. '
        'Suggested fix B: If you wish to create sequence input, use '
        'SequenceFeatures instead of DenseFeatures. '
        'Given (type {}): {}'.format(self.name, type(column), column))

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    return _raise_shared_embedding_column_error()

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class.

    Raises:
      ValueError: if `categorical_column` is not a
        `SequenceCategoricalColumn`.
    """
    column = self.categorical_column
    if not isinstance(column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(column), column))
    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
                                                   state_manager)
    id_tensor = column.get_sparse_tensors(transformation_cache,
                                          state_manager).id_tensor
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor,
        sequence_length=fc_utils.sequence_length_from_sparse_tensor(id_tensor))

  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    return _raise_shared_embedding_column_error()

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def _get_config(self):
    """Serialization is not supported for shared embedding columns."""
    raise NotImplementedError()

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """Deserialization is not supported for shared embedding columns."""
    raise NotImplementedError()
3393
3394
def _check_shape(shape, key):
  """Validates `shape` and returns it as a tuple.

  Args:
    shape: A single dimension, or a sequence of dimensions (anything
      `nest.is_sequence` accepts). Must not be None.
    key: Feature key, used only in error messages.

  Returns:
    The dimensions as a tuple. A single dimension becomes a 1-tuple.

  Raises:
    TypeError: if any dimension is not an `int`.
    ValueError: if any dimension is smaller than 1.
  """
  assert shape is not None
  dims = tuple(shape) if nest.is_sequence(shape) else (shape,)
  for dim in dims:
    if not isinstance(dim, int):
      raise TypeError('shape dimensions must be integer. '
                      'shape: {}, key: {}'.format(dims, key))
    if dim < 1:
      raise ValueError('shape dimensions must be greater than 0. '
                       'shape: {}, key: {}'.format(dims, key))
  return dims
3409
3410
class HashedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('HashedCategoricalColumn',
                           ('key', 'hash_bucket_size', 'dtype'))):
  """see `categorical_column_with_hash_bucket`.

  Assigns each value of feature `key` to one of `hash_bucket_size` buckets
  with `string_to_hash_bucket_fast`. Columns whose dtype is not string have
  their values converted with `as_string` before hashing.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Hashes the values in the feature_column.

    Args:
      input_tensor: SparseTensor of raw feature values.

    Returns:
      SparseTensor with the same indices and dense shape whose values are
      bucket ids.

    Raises:
      ValueError: if `input_tensor` is not a SparseTensor, is neither string
        nor integer, or its integer-ness disagrees with `self.dtype`.
    """
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    raw_values = input_tensor.values
    string_values = (raw_values if self.dtype == dtypes.string
                     else string_ops.as_string(raw_values))
    bucket_ids = string_ops.string_to_hash_bucket_fast(
        string_values, self.hash_bucket_size, name='lookup')
    return sparse_tensor_lib.SparseTensor(
        input_tensor.indices, bucket_ids, input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Hashes the values in the feature_column."""
    raw_feature = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_feature = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    bucket_ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(bucket_ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused.
    bucket_ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(bucket_ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class.

    `dtype` is stored by its string name.
    """
    config = {field: value for field, value in zip(self._fields, self)}
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = dict(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
3517
3518
class VocabularyFileCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('VocabularyFileCategoricalColumn',
                           ('key', 'vocabulary_file', 'vocabulary_size',
                            'num_oov_buckets', 'dtype', 'default_value'))):
  """See `categorical_column_with_vocabulary_file`.

  Maps values of feature `key` to ids with a table built by
  `lookup_ops.index_table_from_file` from `vocabulary_file`.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Creates a lookup table for the vocabulary.

    The table is configured with this column's `vocabulary_file`,
    `vocabulary_size`, `num_oov_buckets` and `default_value`, and is applied
    to `input_tensor`.

    Raises:
      ValueError: if the integer-ness of `input_tensor` and `self.dtype`
        disagree, or `input_tensor` is neither string nor integer.
    """
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if input_tensor.dtype.is_integer:
      # `index_table_from_file` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      lookup_input = math_ops.cast(input_tensor, dtypes.int64)
    else:
      key_dtype = self.dtype
      lookup_input = input_tensor

    # TODO(rohanj): Use state manager to manage the index table creation.
    table = lookup_ops.index_table_from_file(
        vocabulary_file=self.vocabulary_file,
        num_oov_buckets=self.num_oov_buckets,
        vocab_size=self.vocabulary_size,
        default_value=self.default_value,
        key_dtype=key_dtype,
        name='{}_lookup'.format(self.key))
    return table.lookup(lookup_input)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary."""
    raw_feature = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_feature = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature.

    Equal to `vocabulary_size + num_oov_buckets`.
    """
    in_vocab = self.vocabulary_size
    return in_vocab + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused.
    ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class.

    `dtype` is stored by its string name.
    """
    config = {field: value for field, value in zip(self._fields, self)}
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = dict(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
3628
3629
class VocabularyListCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyListCategoricalColumn',
        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
):
  """See `categorical_column_with_vocabulary_list`.

  Maps values of feature `key` to ids with a table built by
  `lookup_ops.index_table_from_tensor` from the in-memory `vocabulary_list`.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Creates a lookup table for the vocabulary list.

    The table is configured with this column's `vocabulary_list`,
    `default_value` and `num_oov_buckets`, and is applied to `input_tensor`.

    Raises:
      ValueError: if the integer-ness of `input_tensor` and `self.dtype`
        disagree, or `input_tensor` is neither string nor integer.
    """
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if input_tensor.dtype.is_integer:
      # `index_table_from_tensor` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      lookup_input = math_ops.cast(input_tensor, dtypes.int64)
    else:
      key_dtype = self.dtype
      lookup_input = input_tensor

    # TODO(rohanj): Use state manager to manage the index table creation.
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=tuple(self.vocabulary_list),
        default_value=self.default_value,
        num_oov_buckets=self.num_oov_buckets,
        dtype=key_dtype,
        name='{}_lookup'.format(self.key))
    return table.lookup(lookup_input)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary list."""
    raw_feature = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_feature = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_feature))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature.

    Equal to `len(vocabulary_list) + num_oov_buckets`.
    """
    in_vocab = len(self.vocabulary_list)
    return in_vocab + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused.
    ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class.

    `dtype` is stored by its string name.
    """
    config = {field: value for field, value in zip(self._fields, self)}
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = dict(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
3739
3740
class IdentityCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('IdentityCategoricalColumn',
                           ('key', 'number_buckets', 'default_value'))):
  """See `categorical_column_with_identity`.

  Integer input values are used directly as ids in `[0, number_buckets)`.
  Out-of-range values either trip runtime assertions (when `default_value` is
  None) or are replaced by `default_value`.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    # Ids are parsed as a variable-length int64 feature.
    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Casts ids to int64 and enforces the `[0, num_buckets)` range.

    Args:
      input_tensor: A `SparseTensor` with an integer dtype.

    Returns:
      A `SparseTensor` with the indices and dense shape of `input_tensor` and
      int64 values.

    Raises:
      ValueError: if `input_tensor` does not have an integer dtype.
    """
    if not input_tensor.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, input_tensor.dtype))

    ids = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
    upper = math_ops.cast(self.num_buckets, dtypes.int64, name='num_buckets')
    lower = math_ops.cast(0, dtypes.int64, name='zero')

    if self.default_value is not None:
      # Replace every out-of-range id with `default_value`.
      out_of_range = math_ops.logical_or(
          ids < lower, ids >= upper, name='out_of_range')
      defaults = array_ops.fill(
          dims=array_ops.shape(ids),
          value=math_ops.cast(self.default_value, dtypes.int64),
          name='default_values')
      ids = array_ops.where(out_of_range, defaults, ids)
    else:
      # Without a default, out-of-range ids make these assertions fail when
      # the graph runs.
      below_upper = check_ops.assert_less(
          ids, upper, data=(ids, upper),
          name='assert_less_than_num_buckets')
      at_least_lower = check_ops.assert_greater_equal(
          ids, lower, data=(ids,),
          name='assert_greater_or_equal_0')
      with ops.control_dependencies((below_upper, at_least_lower)):
        ids = array_ops.identity(ids)

    return sparse_tensor_lib.SparseTensor(
        indices=input_tensor.indices,
        values=ids,
        dense_shape=input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a SparseTensor with identity values."""
    raw_input = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_input = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    # Backed by the `number_buckets` namedtuple field.
    return self.number_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    id_tensor = transformation_cache.get(self, state_manager)
    # Identity ids carry no weights.
    return CategoricalColumn.IdWeightPair(id_tensor, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused: no variables are created.
    id_tensor = inputs.get(self)
    return CategoricalColumn.IdWeightPair(id_tensor, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    return {field: value for field, value in zip(self._fields, self)}

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    # All config values are already plain constructor arguments.
    return cls(**config)
3855
3856
class WeightedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'WeightedCategoricalColumn',
        ('categorical_column', 'weight_feature_key', 'dtype'))):
  """See `weighted_categorical_column`.

  Pairs the ids of `categorical_column` with per-id weights read from the
  feature named `weight_feature_key`.
  """

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_weighted_by_{}'.format(
        self.categorical_column.name, self.weight_feature_key)

  def _add_weight_spec(self, base_spec):
    """Returns a copy of `base_spec` extended with the weight feature's spec.

    `base_spec` comes from `categorical_column`, which may return a dict it
    also keeps (e.g. a cached spec). Editing it in place would leak the weight
    entry into that column's spec and make the next access here fail with the
    "already exists" error below, so a copy is edited instead.

    Args:
      base_spec: Parse spec dict of `categorical_column`.

    Returns:
      A new dict containing `base_spec` plus a `VarLenFeature(self.dtype)`
      entry for `weight_feature_key`.

    Raises:
      ValueError: if `weight_feature_key` already has a spec in `base_spec`.
    """
    config = dict(base_spec)
    if self.weight_feature_key in config:
      raise ValueError('Parse config {} already exists for {}.'.format(
          config[self.weight_feature_key], self.weight_feature_key))
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self._add_weight_spec(self.categorical_column.parse_example_spec)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self._add_weight_spec(
        self.categorical_column._parse_example_spec)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _transform_weight_tensor(self, weight_tensor):
    """Validates the weight feature and returns it as a float SparseTensor.

    Args:
      weight_tensor: The raw weight feature, as a `Tensor` or `SparseTensor`.

    Returns:
      A `SparseTensor` of weights. Dense input is sparsified with 0.0 treated
      as missing, and non-floating values are cast to float32.

    Raises:
      ValueError: if `weight_tensor` is None or its dtype is not `self.dtype`.
    """
    if weight_tensor is None:
      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        weight_tensor)
    if self.dtype != weight_tensor.dtype.base_dtype:
      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
          self.dtype, weight_tensor.dtype))
    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
      # The weight tensor can be a regular Tensor. In this case, sparsify it.
      weight_tensor = _to_sparse_input_and_drop_ignore_values(
          weight_tensor, ignore_value=0.0)
    if not weight_tensor.dtype.is_floating:
      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
    return weight_tensor

  def transform_feature(self, transformation_cache, state_manager):
    """Applies weights to tensor generated from `categorical_column`'."""
    weight_tensor = transformation_cache.get(self.weight_feature_key,
                                             state_manager)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (transformation_cache.get(self.categorical_column, state_manager),
            weight_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Applies weights to tensor generated from `categorical_column`'."""
    weight_tensor = inputs.get(self.weight_feature_key)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (inputs.get(self.categorical_column), weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    # `transform_feature` produced an (ids, weights) tuple.
    tensors = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    tensors = inputs.get(self)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column, self.weight_feature_key]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
3976
3977
class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`.

  After nested crosses are flattened, each leaf key is either a raw feature
  name (string) or a categorical column. Their ids are combined by
  `sparse_ops.sparse_cross_hashed` into `hash_bucket_size` buckets.
  """

  @property
  def _is_v2_column(self):
    # String keys are always acceptable; column keys must be v2 columns.
    return all(
        isinstance(key, six.string_types) or
        (isinstance(key, FeatureColumn) and key._is_v2_column)  # pylint: disable=protected-access
        for key in _collect_leaf_level_keys(self))

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    def _leaf_name(key):
      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
        return key.name
      return key  # key must be a string
    return '_X_'.join(
        sorted(_leaf_name(key) for key in _collect_leaf_level_keys(self)))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        key_spec = key.parse_example_spec
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        key_spec = key._parse_example_spec  # pylint: disable=protected-access
      else:  # key must be a string
        key_spec = {key: parsing_ops.VarLenFeature(dtypes.string)}
      config.update(key_spec)
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _hash_cross_leaves(self, get_raw_feature, get_id_weight_pair):
    """Hashes the cross of all leaf-level keys.

    Args:
      get_raw_feature: Callable mapping a string key to its input tensor.
      get_id_weight_pair: Callable mapping a categorical column to its
        `IdWeightPair`.

    Returns:
      The `SparseTensor` returned by `sparse_ops.sparse_cross_hashed`.

    Raises:
      ValueError: if a leaf column populates `weight_tensor`, or a leaf key is
        neither a string nor a categorical column.
    """
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(get_raw_feature(key))
        continue
      if not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        raise ValueError('Unsupported column type. Given: {}'.format(key))
      ids_and_weights = get_id_weight_pair(key)
      if ids_and_weights.weight_tensor is not None:
        raise ValueError(
            'crossed_column does not support weight_tensor, but the given '
            'column populates weight_tensor. '
            'Given column: {}'.format(key.name))
      feature_tensors.append(ids_and_weights.id_tensor)
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    return self._hash_cross_leaves(
        lambda key: transformation_cache.get(key, state_manager),
        lambda column: column.get_sparse_tensors(transformation_cache,
                                                 state_manager))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    return self._hash_cross_leaves(
        inputs.get,
        lambda column: column._get_sparse_tensors(inputs))  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    crossed_ids = transformation_cache.get(self, state_manager)
    # Hashed crosses carry no weights.
    return CategoricalColumn.IdWeightPair(crossed_ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections, trainable  # Unused: no variables are created.
    crossed_ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(crossed_ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return list(self.keys)

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    # Every key (column or raw string) goes through serialize_feature_column.
    config['keys'] = tuple(serialize_feature_column(key) for key in self.keys)
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = dict(config)
    kwargs['keys'] = tuple(
        deserialize_feature_column(key_config, custom_objects, columns_by_name)
        for key_config in config['keys'])
    return cls(**kwargs)
4117
4118
def _collect_leaf_level_keys(cross):
  """Flattens nested crosses into their leaf-level keys.

  Nested `CrossedColumn`s are expanded depth-first, left to right. The result
  therefore lists the keys in the order they would appear if the cross were
  written out flat.

  Args:
    cross: A `CrossedColumn`.

  Returns:
    A list of strings or `CategoricalColumn` instances.
  """
  leaves = []
  # Keys still to visit, stored reversed so that pop() yields them in order.
  pending = list(cross.keys)[::-1]
  while pending:
    key = pending.pop()
    if isinstance(key, CrossedColumn):
      pending.extend(list(key.keys)[::-1])
    else:
      leaves.append(key)
  return leaves
4135
4136
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drops entries with negative ids from the ids and, if given, the weights.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights sharing the entries of
      `sparse_ids`, or None.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with the same entries removed from
    both. The second element is None when `sparse_weights` is None.
  """
  keep = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is None:
    return sparse_ops.sparse_retain(sparse_ids, keep), None
  # AND-ing with an all-True mask shaped like the weight values leaves `keep`
  # unchanged. It presumably ties the mask to the weights so that mismatched
  # value counts surface as an op error (TODO confirm).
  keep = math_ops.logical_and(
      keep, array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  return (sparse_ops.sparse_retain(sparse_ids, keep),
          sparse_ops.sparse_retain(sparse_weights, keep))
4148
4149
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune entries with non-positive weights (<= 0) from ids and weights.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights, presumably sharing the entries
      of `sparse_ids`, or None.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple. If `sparse_weights` is None, both
    inputs are returned unchanged. Otherwise only the entries whose weight is
    strictly greater than 0 are kept in both tensors.
  """
  if sparse_weights is not None:
    # `greater` (not `greater_equal`): zero-weight entries are pruned as well.
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
4157
4158
class IndicatorColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple('IndicatorColumn', ('categorical_column',))):
  """Represents a one-hot column for use in deep networks.

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` function.
  """

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_indicator'.format(self.categorical_column.name)

  def _transform_id_weight_pair(self, id_weight_pair, size=None):
    """Converts an `IdWeightPair` into a dense indicator tensor.

    Args:
      id_weight_pair: `CategoricalColumn.IdWeightPair` from the wrapped
        categorical column.
      size: Number of buckets, i.e. the size of the last output dimension.
        The v2 path passes `self.variable_shape[-1]` and the deprecated path
        passes `self._variable_shape[-1]`. If None, falls back to
        `self._variable_shape[-1]`.

    Returns:
      A dense `Tensor`. If weights are present, each weight is placed at its
      id's bucket position, with duplicate positions merged by `scatter_nd`.
      Otherwise it is a float32 multi-hot tensor that sums the one-hot
      encodings of each example's ids.
    """
    if size is None:
      size = self._variable_shape[-1]
    id_tensor = id_weight_pair.id_tensor
    weight_tensor = id_weight_pair.weight_tensor

    # If the underlying column is weighted, return the input as a dense tensor.
    if weight_tensor is not None:
      weighted_column = sparse_ops.sparse_merge(
          sp_ids=id_tensor,
          sp_values=weight_tensor,
          vocab_size=int(size))
      # Remove (?, -1) index.
      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
                                                weighted_column.dense_shape)
      # Use scatter_nd to merge duplicated indices if existed,
      # instead of sparse_tensor_to_dense.
      return array_ops.scatter_nd(weighted_column.indices,
                                  weighted_column.values,
                                  weighted_column.dense_shape)

    # Missing positions become -1, which lies outside [0, size) and so
    # one-hot encodes to an all-zero row.
    dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
        id_tensor, default_value=-1)

    # One hot must be float for tf.concat reasons since all other inputs to
    # input_layer are float32.
    one_hot_id_tensor = array_ops.one_hot(
        dense_id_tensor,
        depth=size,
        on_value=1.0,
        off_value=0.0)

    # Reduce to get a multi-hot per example.
    return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])

  def transform_feature(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.

    Raises:
      ValueError: if input rank is not known at graph building time.
    """
    id_weight_pair = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    # Use the non-deprecated `variable_shape`. The deprecated one requires the
    # wrapped column to implement `_num_buckets`, which a pure v2
    # `CategoricalColumn` need not provide.
    return self._transform_id_weight_pair(id_weight_pair,
                                          self.variable_shape[-1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._transform_id_weight_pair(id_weight_pair,
                                          self._variable_shape[-1])

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
    if isinstance(self.categorical_column, FeatureColumn):
      return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
    else:
      return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has been already transformed. Return the intermediate
    # representation created by transform_feature.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has been already transformed. Return the intermediate
    # representation created by transform_feature.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has been already transformed. Return the intermediate
    # representation created by transform_feature.
    dense_tensor = transformation_cache.get(self, state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable since no variables are
    # created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has been already transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
4384
4385
def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Only tensors whose first (batch) dimension is statically known take part in
  the check. The first such tensor sets the expected batch size, and every
  later known batch size must be compatible with it.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns, indexable in the same order as
      `tensors` (used only to name the offending columns in the error).

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # `expected_batch_size` is a `tf.Dimension`; `expected_index` is the
  # position of the tensor that first fixed it.
  expected_batch_size = None
  expected_index = None
  for i, tensor in enumerate(tensors):
    batch_size = tensor_shape.Dimension(
        tensor_shape.dimension_value(tensor.shape[0]))
    if batch_size.value is None:
      continue  # Unknown static batch size: nothing to compare.
    if expected_batch_size is None:
      expected_index = i
      expected_batch_size = batch_size
    elif not expected_batch_size.is_compatible_with(batch_size):
      raise ValueError(
          'Batch size (first dimension) of each feature must be same. '
          'Batch size of columns ({}, {}): ({}, {})'.format(
              columns[expected_index].name, columns[i].name,
              expected_batch_size, batch_size))
4411
4412
class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    # Note the trailing comma: a one-element tuple of field names. (Without it
    # this is a bare string, which namedtuple only happens to accept.)
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column',))):
  """Represents sequences of categorical data.

  Wraps another categorical column (`categorical_column`) and delegates name,
  parse spec, transformation and bucket count to it. The behavior it adds is
  in `get_sparse_tensors` / `_get_sparse_tensors`: the wrapped column's sparse
  ids (and weights, if any) are reshaped to rank 3 so that embeddings are not
  combined during embedding lookup. (Dimension 1 is presumably the sequence
  axis, i.e. ids are laid out as (batch, sequence, ...) -- TODO confirm
  against callers.)
  """

  @property
  def _is_v2_column(self):
    # Only a v2 column if the wrapped column is a v2 `FeatureColumn` too.
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    """Reshapes an `IdWeightPair` to rank 3, keeping the first two dimensions.

    Dimensions 0 and 1 of `id_tensor` are kept and all remaining dimensions
    are flattened into a single third dimension. A rank-2 input therefore
    gains a trailing dimension of size 1 (the product over an empty slice is
    1), and a rank-3 input keeps its shape. `weight_tensor`, when not `None`,
    is reshaped to the same target shape (computed from `id_tensor`).

    Args:
      sparse_tensors: An `IdWeightPair` produced by `categorical_column`.

    Returns:
      A new `CategoricalColumn.IdWeightPair` with the reshaped tensors.
    """
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expands third dimension, if necessary so that embeddings are not
    # combined during embedding lookup. If the tensor is already 3D, leave
    # as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # that doesn't work for dynamically shaped tensors with 0-length at runtime.
    # This happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    The wrapped column's `IdWeightPair.id_tensor` is typically a `batch_size`
    x `num_buckets` `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is
    either a `SparseTensor` of `float` or `None` to indicate all weights should
    be taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
    output of a `VarLenFeature` which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      The wrapped column's `IdWeightPair` with `id_tensor` (and
      `weight_tensor`, if not `None`) reshaped to rank 3: the first two
      dimensions are kept and any remaining ones are flattened into the third.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    # `weight_collections` and `trainable` are accepted to match the legacy
    # interface; they are not used here.
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    # The wrapped column is itself a FeatureColumn; store its serialized form.
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
4528
4529
4530# FeatureColumn serialization, deserialization logic.
4531
4532
4533def _check_config_keys(config, expected_keys):
4534  """Checks that a config has all expected_keys."""
4535  if set(config.keys()) != set(expected_keys):
4536    raise ValueError('Invalid config: {}, expected keys: {}'.format(
4537        config, expected_keys))
4538
4539
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  Intended for serializing parent FeatureColumns from within an
  implementation of FeatureColumn._get_config(); in all other cases prefer
  serialize_feature_columns().

  The result records the FeatureColumn's class name next to its config, so it
  can be deserialized without knowing the class type up front. For example:

  a = numeric_column('price')
  a._get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  # Raw feature keys are passed through unchanged.
  if isinstance(fc, six.string_types):
    return fc
  if not isinstance(fc, FeatureColumn):
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
  class_name = fc.__class__.__name__
  column_config = fc._get_config()  # pylint: disable=protected-access
  return utils.serialize_keras_class_and_config(class_name, column_config)
4587
4588
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn._from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.
  TODO(b/118939620): Simplify code if Keras utils support object deduping.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. It is updated in place: the deserialized column is
      added under its name when no column with that name is present yet.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes), or if it resolves to an object that is not
    a `FeatureColumn` subclass.

  Returns:
    A FeatureColumn corresponding to the input `config`. If `config` is a
    string it is returned unchanged. If `columns_by_name` already holds a
    column with the same name, that existing column is returned instead of
    the freshly built one.
  """
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in [
          BucketizedColumn, EmbeddingColumn, HashedCategoricalColumn,
          IdentityCategoricalColumn, IndicatorColumn, NumericColumn,
          SequenceCategoricalColumn, SequenceDenseColumn, SharedEmbeddingColumn,
          VocabularyFileCategoricalColumn, VocabularyListCategoricalColumn,
          WeightedCategoricalColumn, init_ops.TruncatedNormal
      ]
  }
  if columns_by_name is None:
    columns_by_name = {}

  (cls, cls_config) = utils.class_and_config_for_serialized_keras_object(
      config,
      module_objects=module_feature_column_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  # `custom_objects` may map a name to a function or another non-class object.
  # `issubclass` raises TypeError for those, so check for a class first and
  # report the documented ValueError instead.
  if not (isinstance(cls, type) and issubclass(cls, FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls._from_config(  # pylint: disable=protected-access
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(new_instance.name, new_instance)
4649
4650
def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Each element is serialized with `serialize_feature_column`, in order. The
  resulting list of Keras-style config dicts can be passed to
  `deserialize_feature_columns` to rebuild the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  return list(map(serialize_feature_column, feature_columns))
4668
4669
def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Rebuilds FeatureColumns, in order, from config dicts produced by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  # One lookup table shared by every element, so columns with the same name
  # (e.g. a parent referenced by several columns) resolve to one instance.
  shared_columns = {}
  deserialized = []
  for column_config in configs:
    deserialized.append(
        deserialize_feature_column(
            column_config,
            custom_objects=custom_objects,
            columns_by_name=shared_columns))
  return deserialized
4693