1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This API defines FeatureColumn abstraction.
16
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
20
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
23
241. Feature type:
25
26  * Continuous features can be represented by `numeric_column`.
27  * Categorical features can be represented by any `categorical_column_with_*`
28  column:
29    - `categorical_column_with_vocabulary_list`
30    - `categorical_column_with_vocabulary_file`
31    - `categorical_column_with_hash_bucket`
32    - `categorical_column_with_identity`
33    - `weighted_categorical_column`
34
352. Model type:
36
37  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
38
39    Continuous features can be directly fed into deep neural network models.
40
41      age_column = numeric_column("age")
42
43    To feed sparse features into DNN models, wrap the column with
44    `embedding_column` or `indicator_column`. `indicator_column` is recommended
45    for features with only a few possible values. For features with many
46    possible values, to reduce the size of your model, `embedding_column` is
47    recommended.
48
49      embedded_dept_column = embedding_column(
50          categorical_column_with_vocabulary_list(
51              "department", ["math", "philosophy", ...]), dimension=10)
52
53  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
54
55    Sparse features can be fed directly into linear models. They behave like an
56    indicator column but with an efficient implementation.
57
58      dept_column = categorical_column_with_vocabulary_list("department",
59          ["math", "philosophy", "english"])
60
61    It is recommended that continuous features be bucketized before being
62    fed into linear models.
63
64      bucketized_age_column = bucketized_column(
65          source_column=age_column,
66          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
67
68    Sparse features can be crossed (also known as conjuncted or combined) in
69    order to form non-linearities, and then fed into linear models.
70
71      cross_dept_age_column = crossed_column(
72          columns=["department", bucketized_age_column],
73          hash_bucket_size=1000)
74
75Example of building canned `Estimator`s using FeatureColumns:
76
77  ```python
78  # Define features and transformations
79  deep_feature_columns = [age_column, embedded_dept_column]
80  wide_feature_columns = [dept_column, bucketized_age_column,
81      cross_dept_age_column]
82
83  # Build deep model
84  estimator = DNNClassifier(
85      feature_columns=deep_feature_columns,
86      hidden_units=[500, 250, 50])
87  estimator.train(...)
88
89  # Or build a wide model
90  estimator = LinearClassifier(
91      feature_columns=wide_feature_columns)
92  estimator.train(...)
93
94  # Or build a wide and deep model!
95  estimator = DNNLinearCombinedClassifier(
96      linear_feature_columns=wide_feature_columns,
97      dnn_feature_columns=deep_feature_columns,
98      dnn_hidden_units=[500, 250, 50])
99  estimator.train(...)
100  ```
101
102
103FeatureColumns can also be transformed into a generic input layer for
104custom models using `input_layer`.
105
Example of building a model using FeatureColumns; this can be used in a
`model_fn` which is given to the `tf.estimator.Estimator`:
108
109  ```python
110  # Building model via layers
111
112  deep_feature_columns = [age_column, embedded_dept_column]
113  columns_to_tensor = parse_feature_columns_from_examples(
114      serialized=my_data,
115      feature_columns=deep_feature_columns)
116  first_layer = input_layer(
117      features=columns_to_tensor,
118      feature_columns=deep_feature_columns)
119  second_layer = fully_connected(first_layer, ...)
120  ```
121
122NOTE: Functions prefixed with "_" indicate experimental or private parts of
123the API subject to change, and should not be relied upon!
124"""
125
126from __future__ import absolute_import
127from __future__ import division
128from __future__ import print_function
129
130import abc
131import collections
132import math
133import re
134
135import numpy as np
136import six
137
138from tensorflow.python.eager import context
139from tensorflow.python.feature_column import feature_column as fc_old
140from tensorflow.python.feature_column import utils as fc_utils
141from tensorflow.python.framework import dtypes
142from tensorflow.python.framework import ops
143from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
144from tensorflow.python.framework import tensor_shape
145from tensorflow.python.ops import array_ops
146from tensorflow.python.ops import check_ops
147from tensorflow.python.ops import control_flow_ops
148from tensorflow.python.ops import embedding_ops
149from tensorflow.python.ops import init_ops
150from tensorflow.python.ops import lookup_ops
151from tensorflow.python.ops import math_ops
152from tensorflow.python.ops import parsing_ops
153from tensorflow.python.ops import sparse_ops
154from tensorflow.python.ops import string_ops
155from tensorflow.python.ops import variable_scope
156from tensorflow.python.ops import variables
157from tensorflow.python.platform import gfile
158from tensorflow.python.platform import tf_logging as logging
159from tensorflow.python.training import checkpoint_utils
160from tensorflow.python.training.tracking import base as trackable
161from tensorflow.python.training.tracking import data_structures
162from tensorflow.python.training.tracking import tracking
163from tensorflow.python.util import deprecation
164from tensorflow.python.util import nest
165from tensorflow.python.util import tf_inspect
166from tensorflow.python.util.compat import collections_abc
167from tensorflow.python.util.tf_export import tf_export
168
169
# Deprecation date for the legacy `_FeatureColumn` APIs (`None`: no date set).
_FEATURE_COLUMN_DEPRECATION_DATE = None
# Deprecation message for the legacy `_FeatureColumn` APIs.
_FEATURE_COLUMN_DEPRECATION = (
    'The old _FeatureColumn APIs are being deprecated. '
    'Please use the new FeatureColumn APIs instead.')
174
175
class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.

  This base class only defines the interface: every method raises
  `NotImplementedError` whose message names the method that was called.
  Subclasses provide the actual storage.
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.

    Returns:
      The variable registered under `name` for `feature_column`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name
    # Name the actual public method, like every other stub in this class.
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds a resource to the state under `name` for `feature_column`.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def has_resource(self, feature_column, name):
    """Returns true iff a resource with same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.

    Returns:
      `True` if a resource called `name` exists for `feature_column`,
      otherwise `False`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.has_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.

    Returns:
      The resource registered under `name` for `feature_column`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')
270
271
class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer.

  Variables are created through the owning layer's `add_weight` and registered
  on the layer with `_track_trackable`; trackable resources are attached to the
  layer's `_resources` mapping. Per-column indexes of variables and resources
  are kept for lookups.
  """

  def __init__(self, layer, trainable):
    """Creates an _StateManagerImpl object.

    Args:
      layer: The input layer this state manager is associated with. May be
        `None`, in which case resources are only recorded locally (variable
        creation requires a layer).
      trainable: Whether by default, variables created are trainable or not.
    """
    self._trainable = trainable
    self._layer = layer
    if self._layer is not None and not hasattr(self._layer, '_resources'):
      self._layer._resources = data_structures.Mapping()  # pylint: disable=protected-access
    # feature_column -> {variable name: variable}
    self._cols_to_vars_map = collections.defaultdict(dict)
    # feature_column -> {resource name: resource}
    self._cols_to_resources_map = collections.defaultdict(dict)

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a variable through the layer's `add_weight` and records it.

    The variable is trainable only if both this manager's default `trainable`
    flag and the `trainable` argument are true.

    Raises:
      ValueError: If a variable called `name` already exists for
        `feature_column`.
    """
    # Use `.get` so a lookup alone does not insert an empty entry.
    if name in self._cols_to_vars_map.get(feature_column, {}):
      raise ValueError('Variable already exists.')

    # We explicitly track these variables since `name` is not guaranteed to be
    # unique and disable manual tracking that the add_weight call does.
    with trackable.no_manual_dependency_tracking_scope(self._layer):
      var = self._layer.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)
    if isinstance(var, variables.PartitionedVariable):
      # Track every shard under a distinct name built from its dimension-0
      # offset, so shards of one variable do not collide.
      for v in var:
        part_name = name + '/' + str(v._get_save_slice_info().var_offset[0])  # pylint: disable=protected-access
        self._layer._track_trackable(v, feature_column.name + '/' + part_name)  # pylint: disable=protected-access
    elif isinstance(var, trackable.Trackable):
      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access

    self._cols_to_vars_map[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    """Returns the variable `name` of `feature_column`.

    Raises:
      ValueError: If no such variable was created.
    """
    column_vars = self._cols_to_vars_map.get(feature_column, {})
    if name in column_vars:
      return column_vars[name]
    raise ValueError('Variable does not exist.')

  def add_resource(self, feature_column, resource_name, resource):
    """Records `resource` and, if trackable, attaches it to the layer."""
    self._cols_to_resources_map[feature_column][resource_name] = resource
    # pylint: disable=protected-access
    if self._layer is not None and isinstance(resource, trackable.Trackable):
      # Add trackable resources to the layer for serialization.
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns whether `resource_name` was added for `feature_column`."""
    # `.get` keeps this query free of side effects on the defaultdict.
    return resource_name in self._cols_to_resources_map.get(feature_column, {})

  def get_resource(self, feature_column, resource_name):
    """Returns the resource `resource_name` of `feature_column`.

    Raises:
      ValueError: If no such resource was added.
    """
    column_resources = self._cols_to_resources_map.get(feature_column, {})
    if resource_name not in column_resources:
      raise ValueError('Resource does not exist.')
    return column_resources[resource_name]
349
350
class _StateManagerImplV2(_StateManagerImpl):
  """State manager for DenseFeatures (V2).

  Identical to `_StateManagerImpl` except that variables are created with the
  layer's default `add_weight` getter (no `variable_scope.get_variable`
  override), so no partitioned-variable handling is performed.
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a variable via the layer's `add_weight` and records it.

    Raises:
      ValueError: If `name` already exists for `feature_column`.
    """
    column_vars = self._cols_to_vars_map[feature_column]
    if name in column_vars:
      raise ValueError('Variable already exists.')

    owner = self._layer
    # Tracking is done manually below because `name` may not be unique; the
    # automatic tracking performed by add_weight is disabled here.
    with trackable.no_manual_dependency_tracking_scope(owner):
      new_var = owner.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource)

    if isinstance(new_var, trackable.Trackable):
      tracking_name = feature_column.name + '/' + name
      owner._track_trackable(new_var, tracking_name)  # pylint: disable=protected-access

    column_vars[name] = new_var
    return new_var
379
380
def _transform_features_v2(features, feature_columns, state_manager):
  """Applies every feature column's transformation to `features`.

  Most code should not call this helper directly; see `input_layer` and
  `linear_model` first to check whether they already cover your use case.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = _transform_features_v2(features, columns, state_manager)

  assertCountEqual(columns, transformed.keys())
  ```

  All work runs under a `transform_features` name scope with one nested scope
  per column, and a single `FeatureTransformationCache` built from `features`
  is shared by every column.

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping each (normalized) `FeatureColumn` to its transformed
    `Tensor` or `SparseTensor`.
  """
  columns = _normalize_feature_columns(feature_columns)
  transformed = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    cache = FeatureTransformationCache(features)
    for col in columns:
      column_scope = _sanitize_column_name_for_variable_scope(col.name)
      with ops.name_scope(None, default_name=column_scope):
        transformed[col] = cache.get(col, state_manager)
  return transformed
426
427
@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Builds a parsing spec dictionary from the given feature columns.

  The resulting dictionary can be passed as the `features` argument of
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Columns may share a key as long as they request an equal spec for it.

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance, or if two columns request different specs for the same key.
  """
  spec = {}
  for fcol in feature_columns:
    if not isinstance(fcol, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(fcol))
    column_spec = fcol.parse_example_spec
    for feature_key, feature_spec in six.iteritems(column_spec):
      if feature_key not in spec:
        continue
      if feature_spec != spec[feature_key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(feature_key, feature_spec,
                                         spec[feature_key]))
    spec.update(column_spec)
  return spec
488
489
@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  # Checkpoint warm-starting needs both the path and the tensor name.
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    # Scale the default init by 1/sqrt(dimension), matching
    # `shared_embedding_columns`.
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)
602
603
604@tf_export(v1=['feature_column.shared_embedding_columns'])
605def shared_embedding_columns(categorical_columns,
606                             dimension,
607                             combiner='mean',
608                             initializer=None,
609                             shared_embedding_collection_name=None,
610                             ckpt_to_load_from=None,
611                             tensor_name_in_ckpt=None,
612                             max_norm=None,
613                             trainable=True,
614                             use_safe_embedding_lookup=True):
615  """List of dense columns that convert from sparse, categorical input.
616
617  This is similar to `embedding_column`, except that it produces a list of
618  embedding columns that share the same embedding weights.
619
620  Use this when your inputs are sparse and of the same type (e.g. watched and
621  impression video IDs that share the same vocabulary), and you want to convert
622  them to a dense representation (e.g., to feed to a DNN).
623
624  Inputs must be a list of categorical columns created by any of the
625  `categorical_column_*` function. They must all be of the same type and have
626  the same arguments except `key`. E.g. they can be
627  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
628  all columns could also be weighted_categorical_column.
629
630  Here is an example embedding of two features for a DNNClassifier model:
631
632  ```python
633  watched_video_id = categorical_column_with_vocabulary_file(
634      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
635  impression_video_id = categorical_column_with_vocabulary_file(
636      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
637  columns = shared_embedding_columns(
638      [watched_video_id, impression_video_id], dimension=10)
639
640  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
641
642  label_column = ...
643  def input_fn():
644    features = tf.io.parse_example(
645        ..., features=make_parse_example_spec(columns + [label_column]))
646    labels = features.pop(label_column.name)
647    return features, labels
648
649  estimator.train(input_fn=input_fn, steps=100)
650  ```
651
652  Here is an example using `shared_embedding_columns` with model_fn:
653
654  ```python
655  def model_fn(features, ...):
656    watched_video_id = categorical_column_with_vocabulary_file(
657        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
658    impression_video_id = categorical_column_with_vocabulary_file(
659        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
660    columns = shared_embedding_columns(
661        [watched_video_id, impression_video_id], dimension=10)
662    dense_tensor = input_layer(features, columns)
663    # Form DNN layers, calculate loss, and return EstimatorSpec.
664    ...
665  ```
666
667  Args:
668    categorical_columns: List of categorical columns created by a
669      `categorical_column_with_*` function. These columns produce the sparse IDs
670      that are inputs to the embedding lookup. All columns must be of the same
671      type and have the same arguments except `key`. E.g. they can be
672      categorical_column_with_vocabulary_file with the same vocabulary_file.
673      Some or all columns could also be weighted_categorical_column.
674    dimension: An integer specifying dimension of the embedding, must be > 0.
675    combiner: A string specifying how to reduce if there are multiple entries in
676      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
677      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
678      with bag-of-words columns. Each of this can be thought as example level
679      normalizations on the column. For more information, see
680      `tf.embedding_lookup_sparse`.
681    initializer: A variable initializer function to be used in embedding
682      variable initialization. If not specified, defaults to
683      `truncated_normal_initializer` with mean `0.0` and
684      standard deviation `1/sqrt(dimension)`.
685    shared_embedding_collection_name: Optional name of the collection where
686      shared embedding weights are added. If not given, a reasonable name will
687      be chosen based on the names of `categorical_columns`. This is also used
688      in `variable_scope` when creating shared embedding weights.
689    ckpt_to_load_from: String representing checkpoint name/pattern from which to
690      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
691    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
692      to restore the column weights. Required if `ckpt_to_load_from` is not
693      `None`.
694    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
695      than this value, before combining.
696    trainable: Whether or not the embedding is trainable. Default is True.
697    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
698      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
699      there are no empty rows and all weights and ids are positive at the
700      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
701      input tensors. Defaults to true, consider turning off if the above checks
702      are not needed. Note that having empty rows will not trigger any error
703      though the output result might be 0 or omitted.
704
705  Returns:
706    A list of dense columns that converts from sparse input. The order of
707    results follows the ordering of `categorical_columns`.
708
709  Raises:
710    ValueError: if `dimension` not > 0.
711    ValueError: if any of the given `categorical_columns` is of different type
712      or has different arguments than the others.
713    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
714      is specified.
715    ValueError: if `initializer` is specified and is not callable.
716    RuntimeError: if eager execution is enabled.
717  """
718  if context.executing_eagerly():
719    raise RuntimeError('shared_embedding_columns are not supported when eager '
720                       'execution is enabled.')
721
722  if (dimension is None) or (dimension < 1):
723    raise ValueError('Invalid dimension {}.'.format(dimension))
724  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
725    raise ValueError('Must specify both `ckpt_to_load_from` and '
726                     '`tensor_name_in_ckpt` or none of them.')
727
728  if (initializer is not None) and (not callable(initializer)):
729    raise ValueError('initializer must be callable if specified.')
730  if initializer is None:
731    initializer = init_ops.truncated_normal_initializer(
732        mean=0.0, stddev=1. / math.sqrt(dimension))
733
734  # Sort the columns so the default collection name is deterministic even if the
735  # user passes columns from an unsorted collection, such as dict.values().
736  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
737
738  c0 = sorted_columns[0]
739  num_buckets = c0._num_buckets  # pylint: disable=protected-access
740  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
741    raise ValueError(
742        'All categorical_columns must be subclasses of _CategoricalColumn. '
743        'Given: {}, of type: {}'.format(c0, type(c0)))
744  while isinstance(
745      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
746           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
747    c0 = c0.categorical_column
748  for c in sorted_columns[1:]:
749    while isinstance(
750        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
751            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
752      c = c.categorical_column
753    if not isinstance(c, type(c0)):
754      raise ValueError(
755          'To use shared_embedding_column, all categorical_columns must have '
756          'the same type, or be weighted_categorical_column or sequence column '
757          'of the same type. Given column: {} of type: {} does not match given '
758          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
759    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
760      raise ValueError(
761          'To use shared_embedding_column, all categorical_columns must have '
762          'the same number of buckets. Given column: {} with buckets: {} does  '
763          'not match column: {} with buckets: {}'.format(
764              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access
765
766  if not shared_embedding_collection_name:
767    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
768    shared_embedding_collection_name += '_shared_embedding'
769
770  result = []
771  for column in categorical_columns:
772    result.append(
773        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
774            categorical_column=column,
775            initializer=initializer,
776            dimension=dimension,
777            combiner=combiner,
778            shared_embedding_collection_name=shared_embedding_collection_name,
779            ckpt_to_load_from=ckpt_to_load_from,
780            tensor_name_in_ckpt=tensor_name_in_ckpt,
781            max_norm=max_norm,
782            trainable=trainable,
783            use_safe_embedding_lookup=use_safe_embedding_lookup))
784
785  return result
786
787
@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` function. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = tf.feature_column.shared_embeddings(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embeddings` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = tf.feature_column.shared_embeddings(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List (or any iterable) of categorical columns created
      by a `categorical_column_with_*` function. These columns produce the
      sparse IDs that are inputs to the embedding lookup. All columns must be of
      the same type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `categorical_columns` is empty.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # The columns are traversed twice below (once sorted for validation and
  # naming, once in caller order to build the result). Materialize them so a
  # one-shot iterable such as a generator is not silently exhausted by the
  # first pass, which would make this function return an empty list.
  categorical_columns = list(categorical_columns)
  if not categorical_columns:
    raise ValueError('categorical_columns must not be empty.')

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  # Validate the type before touching `num_buckets`, so a non-categorical input
  # is reported with the documented ValueError rather than an AttributeError.
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  num_buckets = c0.num_buckets
  # Unwrap weighted/sequence wrappers so only the underlying column types are
  # compared across inputs.
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does  '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result
961
962
@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Creates a column for real-valued (numerical) features.

  Example:

  ```python
  price = numeric_column('price')
  columns = [price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  # or
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It serves as the column
      name and as the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers giving the shape of the `Tensor`. A single
      integer may be given instead, meaning a one-dimensional `Tensor` of that
      width. The `Tensor` representing the column has shape
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype`, or an iterable of
      values compatible with `dtype`, used during `tf.Example` parsing when the
      data is missing. With the default of `None`, `tf.io.parse_example` fails
      on any example lacking this column. A single value is broadcast to every
      item; an iterable must have the same shape as `shape`.
    dtype: The type of the values; defaults to `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: Optional function applied to the tensor after
      `default_value` has been applied during parsing. It receives the input
      `Tensor` and returns the output `Tensor` (e.g.
      lambda x: (x - 3.0) / 4.2). Although normalization is the typical use,
      any Tensorflow transformation may be performed here.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int
    ValueError: if any dimension in shape is not a positive integer
    TypeError: if `default_value` is an iterable but not compatible with `shape`
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)

  # Only real integer and floating point dtypes can be cast to float.
  is_real_number_type = dtype.is_integer or dtype.is_floating
  if not is_real_number_type:
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))

  default_value = fc_utils.check_default_value(shape, default_value, dtype, key)

  if normalizer_fn is not None:
    if not callable(normalizer_fn):
      raise TypeError(
          'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  column_kwargs = dict(
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)
  return NumericColumn(key, **column_kwargs)
1038
1039
@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Discretizes a dense numeric input into buckets delimited by `boundaries`.

  Each bucket is closed on the left and open on the right, so
  `boundaries=[0., 1., 2.]` yields the buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, given

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000]
                  [150,   10]
                  [5,    100]]
  ```

  the output is

  ```python
  output = [[0, 3]
            [3, 2]
            [1, 3]]
  ```

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  Because the result is categorical, it can also be crossed with another
  categorical column through `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column turns the numerical feature into a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column produced by
      `numeric_column`.
    boundaries: A strictly increasing list or tuple of floats giving the bucket
      boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is empty, or is not a sorted list or tuple.
  """
  is_numeric = isinstance(source_column, (NumericColumn, fc_old._NumericColumn))  # pylint: disable=protected-access
  if not is_numeric:
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))

  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not isinstance(boundaries, (list, tuple)):
    raise ValueError('boundaries must be a sorted list.')
  # Every adjacent pair must be strictly increasing; duplicates are rejected.
  if any(lower >= upper for lower, upper in zip(boundaries, boundaries[1:])):
    raise ValueError('boundaries must be a sorted list.')

  return BucketizedColumn(source_column, tuple(boundaries))
1124
1125
@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing.
  output_id = Hash(input_feature_string) % bucket_size for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  columns = [keywords]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(
      features, columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int >= 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is `None` or less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  # A single bucket is accepted (every input then maps to id 0).
  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)
1194
1195
@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
def categorical_column_with_vocabulary_file(key,
                                            vocabulary_file,
                                            vocabulary_size=None,
                                            num_oov_buckets=0,
                                            default_value=None,
                                            dtype=dtypes.string):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to its line number. All other values are hashed and assigned an
  ID 50-54.

  ```python
  import tensorflow as tf
  states = tf.compat.v1.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states]
  features = {'states': tf.constant([['CA', 'GA', 'MI', 'TX', 'NY'],
                                     ['NY', 'GA', 'CA', 'MI', 'TX']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(
      features, columns)
  ```

  Example with `default_value`:
  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
  in input, and other values missing from the file, will be assigned ID 0. All
  others are assigned the corresponding line number 1-50.

  ```python
  import tensorflow as tf
  states = tf.compat.v1.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states]
  features = {'states': tf.constant([['CA', 'GA', 'MI', 'TX', 'NY'],
                                     ['NY', 'GA', 'CA', 'MI', 'TX']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(
      features, columns)
  ```

  And to make an embedding with either:

  ```python
  import tensorflow as tf
  states = tf.compat.v1.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [tf.feature_column.embedding_column(states, 3)]
  features = {'states': tf.constant([['CA', 'GA', 'MI', 'TX', 'NY'],
                                     ['NY', 'GA', 'CA', 'MI', 'TX']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing, or `vocabulary_size` is `None`
      and `vocabulary_file` does not exist.
    ValueError: `vocabulary_size` (given, or inferred from the file when
      `None`) is < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  # The v2 signature orders its parameters differently from this v1 one, so
  # forward by keyword to make the mapping explicit and order-independent.
  return categorical_column_with_vocabulary_file_v2(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)
1300
1301
@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
def categorical_column_with_vocabulary_file_v2(key,
                                               vocabulary_file,
                                               vocabulary_size=None,
                                               dtype=dtypes.string,
                                               default_value=None,
                                               num_oov_buckets=0):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to its line number. All other values are hashed and assigned an
  ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal
  `'XX'` in input, and other values missing from the file, will be assigned
  ID 0. All others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the length of `vocabulary_file`.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing, or `vocabulary_size` is `None`
      and `vocabulary_file` does not exist.
    ValueError: `vocabulary_size` (given, or inferred from the file when
      `None`) is < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  if not vocabulary_file:
    raise ValueError('Missing vocabulary_file in {}.'.format(key))

  if vocabulary_size is None:
    if not gfile.Exists(vocabulary_file):
      raise ValueError('vocabulary_file in {} does not exist.'.format(key))

    # Infer the vocabulary size as the number of lines in the file.
    with gfile.GFile(vocabulary_file, mode='rb') as f:
      vocabulary_size = sum(1 for _ in f)
    logging.info(
        'vocabulary_size = %d in %s is inferred from the number of elements '
        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)

  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
  if vocabulary_size < 1:
    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
  if num_oov_buckets:
    # Validate the sign first, so a negative count is reported as invalid
    # rather than as a conflict with `default_value`.
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
    if default_value is not None:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)
  return VocabularyFileCategoricalColumn(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
      default_value=-1 if default_value is None else default_value,
      dtype=dtype)
1420
1421
1422@tf_export('feature_column.categorical_column_with_vocabulary_list')
1423def categorical_column_with_vocabulary_list(key,
1424                                            vocabulary_list,
1425                                            dtype=None,
1426                                            default_value=-1,
1427                                            num_oov_buckets=0):
1428  """A `CategoricalColumn` with in-memory vocabulary.
1429
1430  Use this when your inputs are in string or integer format, and you have an
1431  in-memory vocabulary mapping each value to an integer ID. By default,
1432  out-of-vocabulary values are ignored. Use either (but not both) of
1433  `num_oov_buckets` and `default_value` to specify how to include
1434  out-of-vocabulary values.
1435
1436  For input dictionary `features`, `features[key]` is either `Tensor` or
1437  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1438  and `''` for string, which will be dropped by this feature column.
1439
1440  Example with `num_oov_buckets`:
1441  In the following example, each input in `vocabulary_list` is assigned an ID
1442  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1443  inputs are hashed and assigned an ID 4-5.
1444
1445  ```python
1446  colors = categorical_column_with_vocabulary_list(
1447      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1448      num_oov_buckets=2)
1449  columns = [colors, ...]
1450  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1451  linear_prediction, _, _ = linear_model(features, columns)
1452  ```
1453
1454  Example with `default_value`:
1455  In the following example, each input in `vocabulary_list` is assigned an ID
1456  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1457  inputs are assigned `default_value` 0.
1458
1459
1460  ```python
1461  colors = categorical_column_with_vocabulary_list(
1462      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1463  columns = [colors, ...]
1464  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1465  linear_prediction, _, _ = linear_model(features, columns)
1466  ```
1467
1468  And to make an embedding with either:
1469
1470  ```python
1471  columns = [embedding_column(colors, 3),...]
1472  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1473  dense_tensor = input_layer(features, columns)
1474  ```
1475
1476  Args:
1477    key: A unique string identifying the input feature. It is used as the column
1478      name and the dictionary key for feature parsing configs, feature `Tensor`
1479      objects, and feature columns.
1480    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1481      is mapped to the index of its value (if present) in `vocabulary_list`.
1482      Must be castable to `dtype`.
1483    dtype: The type of features. Only string and integer types are supported. If
1484      `None`, it will be inferred from `vocabulary_list`.
1485    default_value: The integer ID value to return for out-of-vocabulary feature
1486      values, defaults to `-1`. This can not be specified with a positive
1487      `num_oov_buckets`.
1488    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1489      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1490      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
1491      hash of the input value. A positive `num_oov_buckets` can not be specified
1492      with `default_value`.
1493
1494  Returns:
1495    A `CategoricalColumn` with in-memory vocabulary.
1496
1497  Raises:
1498    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1499    ValueError: `num_oov_buckets` is a negative integer.
1500    ValueError: `num_oov_buckets` and `default_value` are both specified.
1501    ValueError: if `dtype` is not integer or string.
1502  """
1503  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1504    raise ValueError(
1505        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1506            vocabulary_list, key))
1507  if len(set(vocabulary_list)) != len(vocabulary_list):
1508    raise ValueError(
1509        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1510            vocabulary_list, key))
1511  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1512  if num_oov_buckets:
1513    if default_value != -1:
1514      raise ValueError(
1515          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1516              key))
1517    if num_oov_buckets < 0:
1518      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1519          num_oov_buckets, key))
1520  fc_utils.assert_string_or_int(
1521      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1522  if dtype is None:
1523    dtype = vocabulary_dtype
1524  elif dtype.is_integer != vocabulary_dtype.is_integer:
1525    raise ValueError(
1526        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1527            dtype, vocabulary_dtype, key))
1528  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1529  fc_utils.assert_key_is_string(key)
1530
1531  return VocabularyListCategoricalColumn(
1532      key=key,
1533      vocabulary_list=tuple(vocabulary_list),
1534      dtype=dtype,
1535      default_value=default_value,
1536      num_oov_buckets=num_oov_buckets)
1537
1538
1539@tf_export('feature_column.categorical_column_with_identity')
1540def categorical_column_with_identity(key, num_buckets, default_value=None):
1541  """A `CategoricalColumn` that returns identity values.
1542
1543  Use this when your inputs are integers in the range `[0, num_buckets)`, and
1544  you want to use the input value itself as the categorical ID. Values outside
1545  this range will result in `default_value` if specified, otherwise it will
1546  fail.
1547
1548  Typically, this is used for contiguous ranges of integer indexes, but
1549  it doesn't have to be. This might be inefficient, however, if many of IDs
1550  are unused. Consider `categorical_column_with_hash_bucket` in that case.
1551
1552  For input dictionary `features`, `features[key]` is either `Tensor` or
1553  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1554  and `''` for string, which will be dropped by this feature column.
1555
1556  In the following examples, each input in the range `[0, 1000000)` is assigned
1557  the same value. All other inputs are assigned `default_value` 0. Note that a
1558  literal 0 in inputs will result in the same default ID.
1559
1560  Linear model:
1561
1562  ```python
1563  import tensorflow as tf
1564  video_id = tf.feature_column.categorical_column_with_identity(
1565      key='video_id', num_buckets=1000000, default_value=0)
1566  columns = [video_id]
1567  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1568  [33,78, 2, 73, 1]])}
1569  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1570  columns)
1571  ```
1572
1573  Embedding for a DNN model:
1574
1575  ```python
1576  import tensorflow as tf
1577  video_id = tf.feature_column.categorical_column_with_identity(
1578      key='video_id', num_buckets=1000000, default_value=0)
1579  columns = [tf.feature_column.embedding_column(video_id, 9)]
1580  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1581  [33,78, 2, 73, 1]])}
1582  input_layer = tf.keras.layers.DenseFeatures(columns)
1583  dense_tensor = input_layer(features)
1584  ```
1585
1586  Args:
1587    key: A unique string identifying the input feature. It is used as the
1588      column name and the dictionary key for feature parsing configs, feature
1589      `Tensor` objects, and feature columns.
1590    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1591    default_value: If set, values outside of range `[0, num_buckets)` will
1592      be replaced with this value. If not set, values >= num_buckets will
1593      cause a failure while values < 0 will be dropped.
1594
1595  Returns:
1596    A `CategoricalColumn` that returns identity values.
1597
1598  Raises:
1599    ValueError: if `num_buckets` is less than one.
1600    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1601  """
1602  if num_buckets < 1:
1603    raise ValueError(
1604        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1605  if (default_value is not None) and (
1606      (default_value < 0) or (default_value >= num_buckets)):
1607    raise ValueError(
1608        'default_value {} not in range [0, {}), column_name {}'.format(
1609            default_value, num_buckets, key))
1610  fc_utils.assert_key_is_string(key)
1611  return IdentityCategoricalColumn(
1612      key=key, number_buckets=num_buckets, default_value=default_value)
1613
1614
@tf_export('feature_column.indicator_column')
def indicator_column(categorical_column):
  """Represents multi-hot representation of given categorical column.

  - For DNN models, `indicator_column` can wrap any `categorical_column_*`
    (e.g., to feed it to a DNN). Consider using `embedding_column` instead when
    the number of buckets/unique values is large.

  - For Wide (aka linear) models, `indicator_column` is the internal
    representation used for a categorical column passed directly (as any
    element in feature_columns) to `linear_model`. See `linear_model` for
    details.

  ```python
  name = indicator_column(categorical_column_with_vocabulary_list(
      'name', ['bob', 'george', 'wanda']))
  columns = [name, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
  ```

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` or `crossed_column` functions.

  Returns:
    An `IndicatorColumn`.

  Raises:
    ValueError: If `categorical_column` is not CategoricalColumn type.
  """
  # Both the v2 base class and the legacy v1 base class are accepted.
  accepted_types = (CategoricalColumn, fc_old._CategoricalColumn)  # pylint: disable=protected-access
  if isinstance(categorical_column, accepted_types):
    return IndicatorColumn(categorical_column)
  raise ValueError(
      'Unsupported input type. Input must be a CategoricalColumn. '
      'Given: {}'.format(categorical_column))
1656
1657
@tf_export('feature_column.weighted_categorical_column')
def weighted_categorical_column(categorical_column,
                                weight_feature_key,
                                dtype=dtypes.float32):
  """Applies weight values to a `CategoricalColumn`.

  Use this when each of your sparse inputs has both an ID and a value. For
  example, if you're representing text documents as a collection of word
  frequencies, you can provide 2 parallel sparse input features ('terms' and
  'frequencies' below).

  Example:

  Input `tf.Example` objects:

  ```proto
  [
    features {
      feature {
        key: "terms"
        value {bytes_list {value: "very" value: "model"}}
      }
      feature {
        key: "frequencies"
        value {float_list {value: 0.3 value: 0.1}}
      }
    },
    features {
      feature {
        key: "terms"
        value {bytes_list {value: "when" value: "course" value: "human"}}
      }
      feature {
        key: "frequencies"
        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
      }
    }
  ]
  ```

  ```python
  categorical_column = categorical_column_with_hash_bucket(
      key='terms', hash_bucket_size=1000)
  weighted_column = weighted_categorical_column(
      categorical_column=categorical_column, weight_feature_key='frequencies')
  columns = [weighted_column, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  This assumes the input dictionary contains a `SparseTensor` for key
  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
  the same indices and dense shape.

  Args:
    categorical_column: A `CategoricalColumn` created by
      `categorical_column_with_*` functions.
    weight_feature_key: String key for weight values.
    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
      are supported.

  Returns:
    A `CategoricalColumn` composed of two sparse features: one represents id,
    the other represents weight (value) of the id feature in that example.

  Raises:
    ValueError: if `dtype` is `None` or is neither an integer nor a floating
      point type.
  """
  # Integer weights are accepted alongside floats: both can be cast to float
  # when the weights are applied.
  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
  return WeightedCategoricalColumn(
      categorical_column=categorical_column,
      weight_feature_key=weight_feature_key,
      dtype=dtype)
1732
1733
1734@tf_export('feature_column.crossed_column')
1735def crossed_column(keys, hash_bucket_size, hash_key=None):
1736  """Returns a column for performing crosses of categorical features.
1737
1738  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
1739  the transformation can be thought of as:
1740    Hash(cartesian product of features) % `hash_bucket_size`
1741
1742  For example, if the input features are:
1743
1744  * SparseTensor referred by first key:
1745
1746    ```python
1747    shape = [2, 2]
1748    {
1749        [0, 0]: "a"
1750        [1, 0]: "b"
1751        [1, 1]: "c"
1752    }
1753    ```
1754
1755  * SparseTensor referred by second key:
1756
1757    ```python
1758    shape = [2, 1]
1759    {
1760        [0, 0]: "d"
1761        [1, 0]: "e"
1762    }
1763    ```
1764
1765  then crossed feature will look like:
1766
1767  ```python
1768   shape = [2, 2]
1769  {
1770      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
1771      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
1772      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
1773  }
1774  ```
1775
1776  Here is an example to create a linear model with crosses of string features:
1777
1778  ```python
1779  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
1780  columns = [keywords_x_doc_terms, ...]
1781  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1782  linear_prediction = linear_model(features, columns)
1783  ```
1784
1785  You could also use vocabulary lookup before crossing:
1786
1787  ```python
1788  keywords = categorical_column_with_vocabulary_file(
1789      'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
1790  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
1791  columns = [keywords_x_doc_terms, ...]
1792  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1793  linear_prediction = linear_model(features, columns)
1794  ```
1795
1796  If an input feature is of numeric type, you can use
1797  `categorical_column_with_identity`, or `bucketized_column`, as in the example:
1798
1799  ```python
1800  # vertical_id is an integer categorical feature.
1801  vertical_id = categorical_column_with_identity('vertical_id', 10K)
1802  price = numeric_column('price')
1803  # bucketized_column converts numerical feature to a categorical one.
1804  bucketized_price = bucketized_column(price, boundaries=[...])
1805  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
1806  columns = [vertical_id_x_price, ...]
1807  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1808  linear_prediction = linear_model(features, columns)
1809  ```
1810
1811  To use crossed column in DNN model, you need to add it in an embedding column
1812  as in this example:
1813
1814  ```python
1815  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
1816  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1817  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1818  ```
1819
1820  Args:
1821    keys: An iterable identifying the features to be crossed. Each element can
1822      be either:
1823      * string: Will use the corresponding feature which must be of string type.
1824      * `CategoricalColumn`: Will use the transformed tensor produced by this
1825        column. Does not support hashed categorical column.
1826    hash_bucket_size: An int > 1. The number of buckets.
1827    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
1828      function to combine the crosses fingerprints on SparseCrossOp (optional).
1829
1830  Returns:
1831    A `CrossedColumn`.
1832
1833  Raises:
1834    ValueError: If `len(keys) < 2`.
1835    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
1836    ValueError: If any of the keys is `HashedCategoricalColumn`.
1837    ValueError: If `hash_bucket_size < 1`.
1838  """
1839  if not hash_bucket_size or hash_bucket_size < 1:
1840    raise ValueError('hash_bucket_size must be > 1. '
1841                     'hash_bucket_size: {}'.format(hash_bucket_size))
1842  if not keys or len(keys) < 2:
1843    raise ValueError(
1844        'keys must be a list with length > 1. Given: {}'.format(keys))
1845  for key in keys:
1846    if (not isinstance(key, six.string_types) and
1847        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
1848      raise ValueError(
1849          'Unsupported key type. All keys must be either string, or '
1850          'categorical column except HashedCategoricalColumn. '
1851          'Given: {}'.format(key))
1852    if isinstance(key,
1853                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
1854      raise ValueError(
1855          'categorical_column_with_hash_bucket is not supported for crossing. '
1856          'Hashing before crossing will increase probability of collision. '
1857          'Instead, use the feature name as a string. Given: {}'.format(key))
1858  return CrossedColumn(
1859      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
1860
1861
@six.add_metaclass(abc.ABCMeta)
class FeatureColumn(object):
  """Represents a feature column abstraction.

  WARNING: Do not subclass this layer unless you know what you are doing:
  the API is subject to future changes.

  To distinguish between the concept of a feature family and a specific binary
  feature within a family, we refer to a feature family like "country" as a
  feature column. For example, we can have a feature in a `tf.Example` format:
    {key: "country",  value: [ "US" ]}
  In this example the value of feature is "US" and "country" refers to the
  column of the feature.

  This class is an abstract class. Users should not create instances of this.
  """

  @abc.abstractproperty
  def name(self):
    """Returns string. Used for naming."""
    pass

  def __lt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    In CPython, `__lt__` must be defined for all objects in the
    sequence being sorted.

    If any objects in the sequence being sorted do not have an `__lt__` method
    compatible with feature column objects (such as strings), then CPython will
    fall back to using the `__gt__` method below.
    https://docs.python.org/3/library/stdtypes.html#list.sort

    Args:
      other: The other object to compare to.

    Returns:
      True if the string representation of this object is lexicographically less
      than the string representation of `other`. For FeatureColumn objects,
      this looks like "<__main__.FeatureColumn object at 0xa>".
    """
    return str(self) < str(other)

  def __gt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    `__gt__` is called when the "other" object being compared during the sort
    does not have `__lt__` defined.
    Example:
    ```
    # __lt__ only class
    class A():
      def __lt__(self, other): return str(self) < str(other)

    a = A()
    a < "b" # True
    "0" < a # Error

    # __lt__ and __gt__ class
    class B():
      def __lt__(self, other): return str(self) < str(other)
      def __gt__(self, other): return str(self) > str(other)

    b = B()
    b < "c" # True
    "0" < b # True
    ```

    Args:
      other: The other object to compare to.

    Returns:
      True if the string representation of this object is lexicographically
      greater than the string representation of `other`. For FeatureColumn
      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
    """
    return str(self) > str(other)

  @abc.abstractmethod
  def transform_feature(self, transformation_cache, state_manager):
    """Returns intermediate representation (usually a `Tensor`).

    Uses `transformation_cache` to create an intermediate representation
    (usually a `Tensor`) that other feature columns can use.

    Example usage of `transformation_cache`:
    Let's say a Feature column depends on raw feature ('raw') and another
    `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
    transformation_cache will be used as follows:

    ```python
    raw_tensor = transformation_cache.get('raw', state_manager)
    fc_tensor = transformation_cache.get(input_fc, state_manager)
    ```

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.
    """
    pass

  @abc.abstractproperty
  def parse_example_spec(self):
    """Returns a `tf.Example` parsing spec as dict.

    It is used for get_parsing_spec for `tf.io.parse_example`. Returned spec is
    a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other
    supported objects. Please check documentation of `tf.io.parse_example` for
    all supported spec objects.

    Let's say a Feature column depends on raw feature ('raw') and another
    `FeatureColumn` (input_fc). One possible implementation of
    parse_example_spec is as follows:

    ```python
    spec = {'raw': tf.io.FixedLenFeature(...)}
    spec.update(input_fc.parse_example_spec)
    return spec
    ```
    """
    pass

  def create_state(self, state_manager):
    """Uses the `state_manager` to create state for the FeatureColumn.

    The default implementation creates nothing; subclasses that own resources
    override it.

    Args:
      state_manager: A `StateManager` to create / access resources such as
        lookup tables and variables.
    """
    pass

  @abc.abstractproperty
  def _is_v2_column(self):
    """Returns whether this FeatureColumn is fully conformant to the new API.

    This is needed for composition type cases where an EmbeddingColumn etc.
    might take in old categorical columns as input and then we want to use the
    old API.
    """
    pass

  @abc.abstractproperty
  def parents(self):
    """Returns a list of immediate raw feature and FeatureColumn dependencies.

    For example:
    # For the following feature columns
    a = numeric_column('f1')
    b = bucketized_column(a, boundaries=[0., 1.])
    c = crossed_column([b, 'f2'], hash_bucket_size=10)
    # The expected parents are:
    a.parents = ['f1']
    b.parents = [a]
    c.parents = [b, 'f2']
    """
    pass

  def get_config(self):
    """Returns the config of the feature column.

    A FeatureColumn config is a Python dictionary (serializable) containing the
    configuration of a FeatureColumn. The same FeatureColumn can be
    reinstantiated later from this configuration.

    The config of a feature column does not include information about feature
    columns depending on it nor the FeatureColumn class name.

    Example with (de)serialization practices followed in this file:
    ```python
    class SerializationExampleFeatureColumn(
        FeatureColumn, collections.namedtuple(
            'SerializationExampleFeatureColumn',
            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):

      def get_config(self):
        # Create a dict from the namedtuple.
        # Python attribute literals can be directly copied from / to the config.
        # For example 'dimension', assuming it is an integer literal.
        config = dict(zip(self._fields, self))

        # (De)serialization of parent FeatureColumns should use the provided
        # (de)serialize_feature_column() methods that take care of de-duping.
        config['parent'] = serialize_feature_column(self.parent)

        # Many objects provide custom (de)serialization e.g: for tf.DType
        # tf.DType.name, tf.as_dtype() can be used.
        config['dtype'] = self.dtype.name

        # Non-trivial dependencies should be Keras-(de)serializable.
        config['normalizer_fn'] = generic_utils.serialize_keras_object(
            self.normalizer_fn)

        return config

      @classmethod
      def from_config(cls, config, custom_objects=None, columns_by_name=None):
        # This should do the inverse transform from `get_config` and construct
        # the namedtuple.
        kwargs = config.copy()
        kwargs['parent'] = deserialize_feature_column(
            config['parent'], custom_objects, columns_by_name)
        kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
        kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
            config['normalizer_fn'], custom_objects=custom_objects)
        return cls(**kwargs)
    ```

    Returns:
      A serializable Dict that can be used to deserialize the object with
      from_config.
    """
    return self._get_config()

  def _get_config(self):
    raise NotImplementedError('Must be implemented in subclasses.')

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """Creates a FeatureColumn from its config.

    This method should be the reverse of `get_config`, capable of instantiating
    the same FeatureColumn from the config dictionary. See `get_config` for an
    example of common (de)serialization practices followed in this file.

    TODO(b/118939620): This is a private method until consensus is reached on
    supporting object deserialization deduping within Keras.

    Args:
      config: A Dict config acquired with `get_config`.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization.
      columns_by_name: A Dict[String, FeatureColumn] of existing columns in
        order to avoid duplication. Should be passed to any calls to
        deserialize_feature_column().

    Returns:
      A FeatureColumn for the input config.
    """
    return cls._from_config(config, custom_objects, columns_by_name)

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    raise NotImplementedError('Must be implemented in subclasses.')
2114
2115
class DenseColumn(FeatureColumn):
  """Abstract base for feature columns whose output is a dense `Tensor`.

  `numeric_column`, `embedding_column` and `indicator_column` are examples of
  columns of this kind.
  """

  @abc.abstractproperty
  def variable_shape(self):
    """The `TensorShape` produced by `get_dense_tensor`, minus the batch dim."""

  @abc.abstractmethod
  def get_dense_tensor(self, transformation_cache, state_manager):
    """Builds and returns this column's dense `Tensor`.

    Model-building functions consume this output. For instance, `input_layer`
    behaves roughly like the following pseudo code:

    ```python
    def input_layer(features, feature_columns, ...):
      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
      return tf.concat(outputs)
    ```

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      `Tensor` of shape [batch_size] + `variable_shape`.
    """
2151
2152
def is_feature_column_v2(feature_columns):
  """Checks whether every column in `feature_columns` uses the V2 API.

  Args:
    feature_columns: An iterable of feature column objects.

  Returns:
    `True` if each element is a `FeatureColumn` whose `_is_v2_column` property
    is truthy (vacuously `True` for an empty iterable), `False` otherwise.
  """
  return all(
      isinstance(column, FeatureColumn) and column._is_v2_column  # pylint: disable=protected-access
      for column in feature_columns)
2161
2162
def _create_weighted_sum(column, transformation_cache, state_manager,
                         sparse_combiner, weight_var):
  """Builds the linear_model weighted sum for a dense or categorical column.

  `CategoricalColumn` instances are routed to the categorical builder, which
  receives `sparse_combiner`. Every other column goes to the dense builder,
  and `sparse_combiner` is not used.

  Args:
    column: The feature column to build the weighted sum for.
    transformation_cache: A `FeatureTransformationCache` object to access
      features.
    state_manager: A `StateManager` to create / access resources.
    sparse_combiner: Combiner forwarded to the categorical builder only.
    weight_var: The weight variable to combine with the column's values.

  Returns:
    Whatever the selected builder returns.
  """
  if not isinstance(column, CategoricalColumn):
    return _create_dense_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        weight_var=weight_var)
  return _create_categorical_column_weighted_sum(
      column=column,
      transformation_cache=transformation_cache,
      state_manager=state_manager,
      sparse_combiner=sparse_combiner,
      weight_var=weight_var)
2179
2180
def _create_dense_column_weighted_sum(column, transformation_cache,
                                      state_manager, weight_var):
  """Builds the linear_model weighted sum for a dense column.

  The column's dense tensor is flattened to
  [batch_size, variable_shape.num_elements()] and right-multiplied by
  `weight_var`.

  Args:
    column: A dense column exposing `get_dense_tensor` and `variable_shape`.
    transformation_cache: A `FeatureTransformationCache` to access features.
    state_manager: A `StateManager` to create / access resources.
    weight_var: Weight matrix; presumably shaped
      [variable_shape.num_elements(), units] -- TODO confirm against caller.

  Returns:
    The `matmul` result, named 'weighted_sum'.
  """
  dense_tensor = column.get_dense_tensor(transformation_cache, state_manager)
  flat_width = column.variable_shape.num_elements()
  rows = array_ops.shape(dense_tensor)[0]
  flattened = array_ops.reshape(dense_tensor, shape=(rows, flat_width))
  return math_ops.matmul(flattened, weight_var, name='weighted_sum')
2189
2190
class CategoricalColumn(FeatureColumn):
  """Represents a categorical feature.

  A categorical feature is typically handled with a `tf.sparse.SparseTensor`
  of IDs.
  """

  # Return type of `get_sparse_tensors`: the id tensor plus an optional
  # weight tensor.
  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
      'IdWeightPair', ('id_tensor', 'weight_tensor'))

  @abc.abstractproperty
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    pass

  @abc.abstractmethod
  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      A `CategoricalColumn.IdWeightPair`.
    """
    pass
2227
2228
def _create_categorical_column_weighted_sum(
    column, transformation_cache, state_manager, sparse_combiner, weight_var):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Create a weighted sum of a categorical column for linear_model.

  Note to maintainer: As an implementation detail, the weighted sum is
  computed via a sparse embedding lookup for efficiency. Mathematically,
  the two are the same.

  To be specific, conceptually, a categorical column can be treated as a
  multi-hot vector. Say:

  ```python
    x = [0 0 1]  # categorical column input
    w = [a b c]  # weights
  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.

  Another example is

  ```python
    x = [0 1 1]  # categorical column input
    w = [a b c]  # weights
  ```
  The weighted sum is `b + c` in this case, which is the same as
  `w[1] + w[2]`.

  For both cases, we can implement the weighted sum via an embedding lookup
  with sparse_combiner = "sum".

  Args:
    column: A `CategoricalColumn` providing `get_sparse_tensors`.
    transformation_cache: A `FeatureTransformationCache` to access features.
    state_manager: A `StateManager` to create / access resources.
    sparse_combiner: Combiner passed to the lookup (e.g. "sum").
    weight_var: The weight variable used as the embedding table.

  Returns:
    The result of `embedding_ops.safe_embedding_lookup_sparse`, named
    'weighted_sum'.
  """

  sparse_tensors = column.get_sparse_tensors(transformation_cache,
                                             state_manager)
  # Collapse every non-batch dimension so the ids form a [batch, -1] matrix.
  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
      array_ops.shape(sparse_tensors.id_tensor)[0], -1
  ])
  weight_tensor = sparse_tensors.weight_tensor
  if weight_tensor is not None:
    # Weights must line up with the reshaped ids, so reshape them the same way.
    weight_tensor = sparse_ops.sparse_reshape(
        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])

  return embedding_ops.safe_embedding_lookup_sparse(
      weight_var,
      id_tensor,
      sparse_weights=weight_tensor,
      combiner=sparse_combiner,
      name='weighted_sum')
2275
2276
class SequenceDenseColumn(FeatureColumn):
  """Base class for feature columns that yield dense sequence data."""

  # Return type of `get_sequence_dense_tensor`: a dense tensor paired with a
  # `sequence_length` tensor.
  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))

  @abc.abstractmethod
  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Builds the dense sequence representation of this column.

    Args:
      transformation_cache: A `FeatureTransformationCache` used to look up
        (and memoize) features.
      state_manager: A `StateManager` through which resources such as lookup
        tables are created or accessed.

    Returns:
      A `SequenceDenseColumn.TensorSequenceLengthPair`.
    """
2294
2295
class FeatureTransformationCache(object):
  """Handles caching of transformations while building the model.

  `FeatureColumn` specifies how to digest an input column to the network. Some
  feature columns require data transformations. This class caches those
  transformations.

  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and in a cross. In that case we should create
  only one bucketization op instead of creating ops for each feature column
  separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.

  Example:
  We're trying to use the following `FeatureColumn`s:

  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
  ```

  If we transform each column independently, then we'll get duplication of
  bucketization (one for cross, one for bucketization itself).
  The `FeatureTransformationCache` eliminates this duplication.
  """

  def __init__(self, features):
    """Creates a `FeatureTransformationCache`.

    Args:
      features: A mapping whose values are `Tensor` or `SparseTensor` objects,
        or objects that can be converted to them via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
        signifies a base feature (not-transformed). A `FeatureColumn` key
        means that this `Tensor` is the output of an existing `FeatureColumn`
        which can be reused.
    """
    # Copied so later mutation of the caller's mapping does not affect us.
    self._features = features.copy()
    # Memoized results of `get`, keyed by the same `key` passed to `get`.
    self._feature_tensors = {}

  def get(self, key, state_manager, training=None):
    """Returns a `Tensor` for the given key.

    A `str` key is used to access a base feature (not-transformed). When a
    `FeatureColumn` is passed, the transformed feature is returned if it
    already exists, otherwise the given `FeatureColumn` is asked to provide its
    transformed output, which is then cached.

    Args:
      key: a `str` or a `FeatureColumn`.
      state_manager: A StateManager object that holds the FeatureColumn state.
      training: Boolean indicating whether the column is being used in
        training mode. This argument is passed to the transform_feature method
        of any `FeatureColumn` that takes a `training` argument. For example, if
        a `FeatureColumn` performed dropout, it could expose a `training`
        argument to control whether the dropout should be applied.

    Returns:
      The transformed `Tensor` corresponding to the `key`.

    Raises:
      ValueError: if key is not found or a transformed `Tensor` cannot be
        computed.
      TypeError: if `key` is neither a `str` nor a `FeatureColumn`.
    """
    if key in self._feature_tensors:
      # FeatureColumn is already transformed or converted.
      return self._feature_tensors[key]

    if key in self._features:
      # Raw input (or a column output supplied up front by the caller):
      # convert it once and memoize the result.
      feature_tensor = self._get_raw_feature_as_tensor(key)
      self._feature_tensors[key] = feature_tensor
      return feature_tensor

    if isinstance(key, six.string_types):
      raise ValueError('Feature {} is not in features dictionary.'.format(key))

    if not isinstance(key, FeatureColumn):
      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
                      'Provided: {}'.format(key))

    column = key
    logging.debug('Transforming feature_column %s.', column)

    # Some columns may need information about whether the transformation is
    # happening in training or prediction mode, but not all columns expose this
    # argument.
    # NOTE: this fallback also fires for a TypeError raised from *inside* a
    # `transform_feature` that does accept `training`; in that case the call
    # is silently retried without `training`.
    try:
      transformed = column.transform_feature(
          self, state_manager, training=training)
    except TypeError:
      transformed = column.transform_feature(self, state_manager)
    if transformed is None:
      raise ValueError('Column {} is not supported.'.format(column.name))
    self._feature_tensors[column] = transformed
    return transformed

  def _get_raw_feature_as_tensor(self, key):
    """Gets the raw_feature (keyed by `key`) as `tensor`.

    The raw feature is converted to (sparse) tensor and maybe expand dim.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. This supports dynamic rank also. For a statically known
    rank of 0 a `ValueError` is raised; for a dynamic rank, a runtime
    assertion guards against rank 0 instead.

    Args:
      key: A `str` key to access the raw feature.

    Returns:
      A `Tensor` or `SparseTensor`.

    Raises:
      ValueError: if the raw feature has (static) rank 0.
    """
    raw_feature = self._features[key]
    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        raw_feature)

    def expand_dims(input_tensor):
      # Input_tensor must have rank 1.
      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
        return sparse_ops.sparse_reshape(
            input_tensor, [array_ops.shape(input_tensor)[0], 1])
      else:
        return array_ops.expand_dims(input_tensor, -1)

    rank = feature_tensor.get_shape().ndims
    if rank is not None:
      if rank == 0:
        raise ValueError(
            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))
      return feature_tensor if rank != 1 else expand_dims(feature_tensor)

    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_positive(
            array_ops.rank(feature_tensor),
            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))]):
      return control_flow_ops.cond(
          math_ops.equal(1, array_ops.rank(feature_tensor)),
          lambda: expand_dims(feature_tensor),
          lambda: feature_tensor)
2443
2444
2445# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.

  If `input_tensor` is already a `SparseTensor`, just return it.

  Args:
    input_tensor: A `Tensor` or `SparseTensor`, or a value convertible to one
      via `convert_to_tensor_or_sparse_tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, a default based on
      `input_tensor`'s dtype is used: '' for `string`, -1 for integer types,
      and otherwise the value obtained by calling the dtype's numpy type with
      no arguments (e.g. 0.0 for floats).

  Returns:
    A `SparseTensor` with the same shape as `input_tensor`.
  """
  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
      input_tensor)
  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
    return input_tensor
  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
    if ignore_value is None:
      if input_tensor.dtype == dtypes.string:
        # Special-cased because TF strings are converted to numpy objects, so
        # the generic default below would not yield ''.
        ignore_value = ''
      elif input_tensor.dtype.is_integer:
        ignore_value = -1  # -1 has a special meaning of missing feature
      else:
        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
        # constructing a new numpy object of the given type, which yields the
        # default value for that type.
        ignore_value = input_tensor.dtype.as_numpy_dtype()
    ignore_value = math_ops.cast(
        ignore_value, input_tensor.dtype, name='ignore_value')
    # Coordinates of every cell that should be kept.
    indices = array_ops.where_v2(
        math_ops.not_equal(input_tensor, ignore_value), name='indices')
    return sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=array_ops.gather_nd(input_tensor, indices, name='values'),
        dense_shape=array_ops.shape(
            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2488
2489
def _normalize_feature_columns(feature_columns):
  """Validates `feature_columns` and returns them as a name-sorted list.

  Accepts a single `FeatureColumn`, an iterator, or any other non-dict
  iterable of columns. Checks that every item is a `FeatureColumn`, that there
  is at least one, and that no two columns share a `name`.

  Args:
    feature_columns: The raw feature columns, usually passed by users.

  Returns:
    A new list containing the columns, sorted by `name`.

  Raises:
    ValueError: if `feature_columns` is a dict, contains an item that is not a
      `FeatureColumn`, is empty, or contains duplicate column names.
  """
  if isinstance(feature_columns, FeatureColumn):
    feature_columns = [feature_columns]
  if isinstance(feature_columns, collections_abc.Iterator):
    # Materialize one-shot iterators; they are traversed several times below.
    feature_columns = list(feature_columns)
  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  for candidate in feature_columns:
    if isinstance(candidate, FeatureColumn):
      continue
    raise ValueError('Items of feature_columns must be a FeatureColumn. '
                     'Given (type {}): {}.'.format(type(candidate), candidate))

  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')

  columns_by_name = {}
  for candidate in feature_columns:
    name = candidate.name
    if name in columns_by_name:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer to '
                       'same base feature. Either one must be discarded or a '
                       'duplicated but renamed item must be inserted in '
                       'features dict.'.format(candidate,
                                               columns_by_name[name]))
    columns_by_name[name] = candidate

  return sorted(feature_columns, key=lambda column: column.name)
2533
2534
class NumericColumn(
    DenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'NumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """See `numeric_column`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {
        self.key:
            parsing_ops.FixedLenFeature(self.shape, self.dtype,
                                        self.default_value)
    }

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Applies `normalizer_fn` (if any), then casts the result to float32.

    Shared by `transform_feature` and the deprecated `_transform_feature`.

    Raises:
      ValueError: If `input_tensor` is a `SparseTensor`.
    """
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'The corresponding Tensor of numerical column must be a Tensor. '
          'SparseTensor is not supported. key: {}'.format(self.key))
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    # The cast happens after normalization, so normalizer_fn receives the
    # tensor in whatever dtype the cache returned.
    return math_ops.cast(input_tensor, dtypes.float32)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    return self._transform_input_tensor(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.
    Raises:
      ValueError: If a SparseTensor is passed in.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(self.shape)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing numeric feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.
    """
    # Feature has been already transformed. Return the intermediate
    # representation created by `transform_feature`.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    return inputs.get(self)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    # Callables and dtypes are not directly serializable: store the
    # normalizer via Keras object serialization and the dtype by name.
    config['normalizer_fn'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
        self.normalizer_fn)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    kwargs = _standardize_and_copy_config(config)
    # Inverse of `get_config`: rebuild the normalizer and the dtype object.
    kwargs['normalizer_fn'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
        config['normalizer_fn'], custom_objects=custom_objects)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])

    return cls(**kwargs)
2660
2661
class BucketizedColumn(
    DenseColumn,
    CategoricalColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('BucketizedColumn',
                           ('source_column', 'boundaries'))):
  """See `bucketized_column`.

  Acts both as a dense column (one-hot encoding of the bucket ids) and as a
  categorical column (sparse bucket ids).
  """

  @property
  def _is_v2_column(self):
    return (isinstance(self.source_column, FeatureColumn) and
            self.source_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_bucketized'.format(self.source_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.source_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.source_column._parse_example_spec  # pylint: disable=protected-access

  def _transform_input_tensor(self, source_tensor):
    """Bucketizes `source_tensor` against `self.boundaries`.

    Shared by `transform_feature` and the deprecated `_transform_feature` so
    the two code paths cannot drift apart (mirrors `NumericColumn`).
    """
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = inputs.get(self.source_column)
    return self._transform_input_tensor(source_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = transformation_cache.get(self.source_column, state_manager)
    return self._transform_input_tensor(source_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    # One one-hot slot per bucket, appended to the source column's shape.
    return tensor_shape.TensorShape(
        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def _get_dense_tensor_for_input_tensor(self, input_tensor):
    """One-hot encodes bucket ids into `len(boundaries) + 1` float slots."""
    return array_ops.one_hot(
        indices=math_ops.cast(input_tensor, dtypes.int64),
        depth=len(self.boundaries) + 1,
        on_value=1.,
        off_value=0.)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns one hot encoded dense `Tensor`."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    # By construction, source_column is always one-dimensional.
    return (len(self.boundaries) + 1) * self.source_column.shape[0]

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
    """Builds an `IdWeightPair` of bucket ids made unique per dimension.

    Assumes `input_tensor` holds `batch_size * source_dimension` bucket ids
    (TODO confirm shape against `transform_feature`); the id at dimension `d`
    is offset by `d * (len(boundaries) + 1)`.
    """
    batch_size = array_ops.shape(input_tensor)[0]
    # By construction, source_column is always one-dimensional.
    source_dimension = self.source_column.shape[0]

    # i1: batch (row) index, each repeated `source_dimension` times.
    i1 = array_ops.reshape(
        array_ops.tile(
            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
            [1, source_dimension]),
        (-1,))
    # i2: dimension (column) index, tiled once per batch row.
    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
    # Flatten the bucket indices and unique them across dimensions
    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
    bucket_indices = (
        array_ops.reshape(input_tensor, (-1,)) +
        (len(self.boundaries) + 1) * i2)

    indices = math_ops.cast(
        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
    dense_shape = math_ops.cast(
        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
    sparse_tensor = sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=bucket_indices,
        dense_shape=dense_shape)
    return CategoricalColumn.IdWeightPair(sparse_tensor, None)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.source_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['source_column'] = serialize_feature_column(self.source_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['source_column'] = deserialize_feature_column(
        config['source_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
2815
2816
2817class EmbeddingColumn(
2818    DenseColumn,
2819    SequenceDenseColumn,
2820    fc_old._DenseColumn,  # pylint: disable=protected-access
2821    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
2822    collections.namedtuple(
2823        'EmbeddingColumn',
2824        ('categorical_column', 'dimension', 'combiner', 'initializer',
2825         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
2826         'use_safe_embedding_lookup'))):
2827  """See `embedding_column`."""
2828
2829  def __new__(cls,
2830              categorical_column,
2831              dimension,
2832              combiner,
2833              initializer,
2834              ckpt_to_load_from,
2835              tensor_name_in_ckpt,
2836              max_norm,
2837              trainable,
2838              use_safe_embedding_lookup=True):
2839    return super(EmbeddingColumn, cls).__new__(
2840        cls,
2841        categorical_column=categorical_column,
2842        dimension=dimension,
2843        combiner=combiner,
2844        initializer=initializer,
2845        ckpt_to_load_from=ckpt_to_load_from,
2846        tensor_name_in_ckpt=tensor_name_in_ckpt,
2847        max_norm=max_norm,
2848        trainable=trainable,
2849        use_safe_embedding_lookup=use_safe_embedding_lookup)
2850
2851  @property
2852  def _is_v2_column(self):
2853    return (isinstance(self.categorical_column, FeatureColumn) and
2854            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
2855
2856  @property
2857  def name(self):
2858    """See `FeatureColumn` base class."""
2859    return '{}_embedding'.format(self.categorical_column.name)
2860
2861  @property
2862  def parse_example_spec(self):
2863    """See `FeatureColumn` base class."""
2864    return self.categorical_column.parse_example_spec
2865
2866  @property
2867  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2868                          _FEATURE_COLUMN_DEPRECATION)
2869  def _parse_example_spec(self):
2870    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
2871
2872  def transform_feature(self, transformation_cache, state_manager):
2873    """Transforms underlying `categorical_column`."""
2874    return transformation_cache.get(self.categorical_column, state_manager)
2875
2876  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2877                          _FEATURE_COLUMN_DEPRECATION)
2878  def _transform_feature(self, inputs):
2879    return inputs.get(self.categorical_column)
2880
2881  @property
2882  def variable_shape(self):
2883    """See `DenseColumn` base class."""
2884    return tensor_shape.TensorShape([self.dimension])
2885
2886  @property
2887  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2888                          _FEATURE_COLUMN_DEPRECATION)
2889  def _variable_shape(self):
2890    return self.variable_shape
2891
2892  def create_state(self, state_manager):
2893    """Creates the embedding lookup variable."""
2894    default_num_buckets = (self.categorical_column.num_buckets
2895                           if self._is_v2_column
2896                           else self.categorical_column._num_buckets)   # pylint: disable=protected-access
2897    num_buckets = getattr(self.categorical_column, 'num_buckets',
2898                          default_num_buckets)
2899    embedding_shape = (num_buckets, self.dimension)
2900    state_manager.create_variable(
2901        self,
2902        name='embedding_weights',
2903        shape=embedding_shape,
2904        dtype=dtypes.float32,
2905        trainable=self.trainable,
2906        use_resource=True,
2907        initializer=self.initializer)
2908
2909  def _get_dense_tensor_internal_helper(self, sparse_tensors,
2910                                        embedding_weights):
2911    sparse_ids = sparse_tensors.id_tensor
2912    sparse_weights = sparse_tensors.weight_tensor
2913
2914    if self.ckpt_to_load_from is not None:
2915      to_restore = embedding_weights
2916      if isinstance(to_restore, variables.PartitionedVariable):
2917        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
2918      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
2919          self.tensor_name_in_ckpt: to_restore
2920      })
2921
2922    sparse_id_rank = tensor_shape.dimension_value(
2923        sparse_ids.dense_shape.get_shape()[0])
2924    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
2925    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
2926        sparse_id_rank <= 2):
2927      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
2928    # Return embedding lookup result.
2929    return embedding_lookup_sparse(
2930        embedding_weights,
2931        sparse_ids,
2932        sparse_weights,
2933        combiner=self.combiner,
2934        name='%s_weights' % self.name,
2935        max_norm=self.max_norm)
2936
2937  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
2938    """Private method that follows the signature of get_dense_tensor."""
2939    embedding_weights = state_manager.get_variable(
2940        self, name='embedding_weights')
2941    return self._get_dense_tensor_internal_helper(sparse_tensors,
2942                                                  embedding_weights)
2943
2944  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
2945                                     trainable):
2946    """Private method that follows the signature of _get_dense_tensor."""
2947    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
2948    if (weight_collections and
2949        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
2950      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
2951    embedding_weights = variable_scope.get_variable(
2952        name='embedding_weights',
2953        shape=embedding_shape,
2954        dtype=dtypes.float32,
2955        initializer=self.initializer,
2956        trainable=self.trainable and trainable,
2957        collections=weight_collections)
2958    return self._get_dense_tensor_internal_helper(sparse_tensors,
2959                                                  embedding_weights)
2960
2961  def get_dense_tensor(self, transformation_cache, state_manager):
2962    """Returns tensor after doing the embedding lookup.
2963
2964    Args:
2965      transformation_cache: A `FeatureTransformationCache` object to access
2966        features.
2967      state_manager: A `StateManager` to create / access resources such as
2968        lookup tables.
2969
2970    Returns:
2971      Embedding lookup tensor.
2972
2973    Raises:
2974      ValueError: `categorical_column` is SequenceCategoricalColumn.
2975    """
2976    if isinstance(self.categorical_column, SequenceCategoricalColumn):
2977      raise ValueError(
2978          'In embedding_column: {}. '
2979          'categorical_column must not be of type SequenceCategoricalColumn. '
2980          'Suggested fix A: If you wish to use DenseFeatures, use a '
2981          'non-sequence categorical_column_with_*. '
2982          'Suggested fix B: If you wish to create sequence input, use '
2983          'SequenceFeatures instead of DenseFeatures. '
2984          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
2985                                       self.categorical_column))
2986    # Get sparse IDs and weights.
2987    sparse_tensors = self.categorical_column.get_sparse_tensors(
2988        transformation_cache, state_manager)
2989    return self._get_dense_tensor_internal(sparse_tensors, state_manager)
2990
2991  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2992                          _FEATURE_COLUMN_DEPRECATION)
2993  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2994    if isinstance(
2995        self.categorical_column,
2996        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
2997      raise ValueError(
2998          'In embedding_column: {}. '
2999          'categorical_column must not be of type _SequenceCategoricalColumn. '
3000          'Suggested fix A: If you wish to use DenseFeatures, use a '
3001          'non-sequence categorical_column_with_*. '
3002          'Suggested fix B: If you wish to create sequence input, use '
3003          'SequenceFeatures instead of DenseFeatures. '
3004          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3005                                       self.categorical_column))
3006    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
3007        inputs, weight_collections, trainable)
3008    return self._old_get_dense_tensor_internal(sparse_tensors,
3009                                               weight_collections, trainable)
3010
3011  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3012    """See `SequenceDenseColumn` base class."""
3013    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3014      raise ValueError(
3015          'In embedding_column: {}. '
3016          'categorical_column must be of type SequenceCategoricalColumn '
3017          'to use SequenceFeatures. '
3018          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3019          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3020                                       self.categorical_column))
3021    sparse_tensors = self.categorical_column.get_sparse_tensors(
3022        transformation_cache, state_manager)
3023    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
3024                                                   state_manager)
3025    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3026        sparse_tensors.id_tensor)
3027    return SequenceDenseColumn.TensorSequenceLengthPair(
3028        dense_tensor=dense_tensor, sequence_length=sequence_length)
3029
3030  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3031                          _FEATURE_COLUMN_DEPRECATION)
3032  def _get_sequence_dense_tensor(self,
3033                                 inputs,
3034                                 weight_collections=None,
3035                                 trainable=None):
3036    if not isinstance(
3037        self.categorical_column,
3038        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
3039      raise ValueError(
3040          'In embedding_column: {}. '
3041          'categorical_column must be of type SequenceCategoricalColumn '
3042          'to use SequenceFeatures. '
3043          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3044          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3045                                       self.categorical_column))
3046    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3047    dense_tensor = self._old_get_dense_tensor_internal(
3048        sparse_tensors,
3049        weight_collections=weight_collections,
3050        trainable=trainable)
3051    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3052        sparse_tensors.id_tensor)
3053    return SequenceDenseColumn.TensorSequenceLengthPair(
3054        dense_tensor=dense_tensor, sequence_length=sequence_length)
3055
3056  @property
3057  def parents(self):
3058    """See 'FeatureColumn` base class."""
3059    return [self.categorical_column]
3060
3061  def get_config(self):
3062    """See 'FeatureColumn` base class."""
3063    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
3064    config = dict(zip(self._fields, self))
3065    config['categorical_column'] = serialization.serialize_feature_column(
3066        self.categorical_column)
3067    config['initializer'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
3068        self.initializer)
3069    return config
3070
3071  @classmethod
3072  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3073    """See 'FeatureColumn` base class."""
3074    if 'use_safe_embedding_lookup' not in config:
3075      config['use_safe_embedding_lookup'] = True
3076    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
3077    _check_config_keys(config, cls._fields)
3078    kwargs = _standardize_and_copy_config(config)
3079    kwargs['categorical_column'] = serialization.deserialize_feature_column(
3080        config['categorical_column'], custom_objects, columns_by_name)
3081    all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
3082    kwargs['initializer'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
3083        config['initializer'],
3084        module_objects=all_initializers,
3085        custom_objects=custom_objects)
3086    return cls(**kwargs)
3087
3088
3089def _raise_shared_embedding_column_error():
3090  raise ValueError('SharedEmbeddingColumns are not supported in '
3091                   '`linear_model` or `input_layer`. Please use '
3092                   '`DenseFeatures` or `LinearModel` instead.')
3093
3094
class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
  """Owns the embedding variable shared by several `SharedEmbeddingColumn`s.

  Calling an instance with a categorical column returns a
  `SharedEmbeddingColumn` bound to this creator. The embedding variable is
  created lazily, once per default graph, on first access to
  `embedding_weights`.
  """

  def __init__(self,
               dimension,
               initializer,
               ckpt_to_load_from,
               tensor_name_in_ckpt,
               num_buckets,
               trainable,
               name='shared_embedding_column_creator',
               use_safe_embedding_lookup=True):
    self._dimension = dimension
    self._initializer = initializer
    self._ckpt_to_load_from = ckpt_to_load_from
    self._tensor_name_in_ckpt = tensor_name_in_ckpt
    self._num_buckets = num_buckets
    self._trainable = trainable
    self._name = name
    self._use_safe_embedding_lookup = use_safe_embedding_lookup
    # Graph key -> embedding variable created in that graph.
    self._embedding_weights = {}

  def __call__(self, categorical_column, combiner, max_norm):
    """Returns a `SharedEmbeddingColumn` that uses this creator's variable."""
    return SharedEmbeddingColumn(
        categorical_column, self, combiner, max_norm,
        self._use_safe_embedding_lookup)

  @property
  def embedding_weights(self):
    """Returns the shared embedding variable for the current default graph.

    On first access within a graph the variable is created with shape
    `(num_buckets, dimension)` and, if a checkpoint was configured, registered
    for initialization from it. Later accesses in the same graph return the
    cached variable.
    """
    graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if graph_key in self._embedding_weights:
      return self._embedding_weights[graph_key]

    new_var = variable_scope.get_variable(
        name=self._name,
        shape=(self._num_buckets, self._dimension),
        dtype=dtypes.float32,
        initializer=self._initializer,
        trainable=self._trainable)

    if self._ckpt_to_load_from is not None:
      # A partitioned variable is passed as its list of shard variables.
      restore_target = new_var
      if isinstance(restore_target, variables.PartitionedVariable):
        restore_target = restore_target._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(
          self._ckpt_to_load_from, {self._tensor_name_in_ckpt: restore_target})

    self._embedding_weights[graph_key] = new_var
    return self._embedding_weights[graph_key]

  @property
  def dimension(self):
    """The embedding dimension (second axis of the shared variable)."""
    return self._dimension
3145
3146
class SharedEmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'SharedEmbeddingColumn',
        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
         'max_norm', 'use_safe_embedding_lookup'))):
  """See `embedding_column`.

  Looks up `categorical_column` ids in the embedding variable owned by
  `shared_embedding_column_creator`, so several columns can share one table.
  Only the v2 API is supported; the underscore-prefixed v1 hooks raise a
  `ValueError`.
  """

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner,
              max_norm,
              use_safe_embedding_lookup=True):
    # Same as the namedtuple constructor, but with a default for
    # `use_safe_embedding_lookup`.
    return super(SharedEmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        shared_embedding_column_creator=shared_embedding_column_creator,
        combiner=combiner,
        max_norm=max_norm,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    base_name = self.categorical_column.name
    return '{}_shared_embedding'.format(base_name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    # Parsing is delegated entirely to the wrapped categorical column.
    return self.categorical_column.parse_example_spec

  @property
  def _parse_example_spec(self):
    return _raise_shared_embedding_column_error()

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    source_column = self.categorical_column
    return transformation_cache.get(source_column, state_manager)

  def _transform_feature(self, inputs):
    return _raise_shared_embedding_column_error()

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    embedding_dim = self.shared_embedding_column_creator.dimension
    return tensor_shape.TensorShape([embedding_dim])

  @property
  def _variable_shape(self):
    return _raise_shared_embedding_column_error()

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Embeds this column's ids with the shared embedding variable.

    Takes the same arguments as `get_dense_tensor`.
    """
    # The caller's variable scope may be shared by all shared embeddings, so
    # a per-column name_scope keeps the ops of different columns distinct.
    with ops.name_scope(None, default_name=self.name):
      id_weight_pair = self.categorical_column.get_sparse_tensors(
          transformation_cache, state_manager)
      ids = id_weight_pair.id_tensor
      weights = id_weight_pair.weight_tensor

      shared_weights = self.shared_embedding_column_creator.embedding_weights

      # Static rank of the ids (None when unknown at graph-construction time).
      ids_rank = tensor_shape.dimension_value(ids.dense_shape.get_shape()[0])
      # Fall back to the safe lookup unless it was explicitly disabled and
      # the ids are statically known to be at most rank 2.
      if self.use_safe_embedding_lookup or ids_rank is None or ids_rank > 2:
        lookup_fn = embedding_ops.safe_embedding_lookup_sparse
      else:
        lookup_fn = embedding_ops.embedding_lookup_sparse_v2
      return lookup_fn(
          shared_weights,
          ids,
          weights,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embedding lookup result.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    column = self.categorical_column
    if isinstance(column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(column), column))
    return self._get_dense_tensor_internal(transformation_cache, state_manager)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    return _raise_shared_embedding_column_error()

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class.

    Raises:
      ValueError: If `categorical_column` is not a
        `SequenceCategoricalColumn`.
    """
    column = self.categorical_column
    if not isinstance(column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(column), column))
    embedded = self._get_dense_tensor_internal(transformation_cache,
                                               state_manager)
    id_weight_pair = column.get_sparse_tensors(transformation_cache,
                                               state_manager)
    lengths = fc_utils.sequence_length_from_sparse_tensor(
        id_weight_pair.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=embedded, sequence_length=lengths)

  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    return _raise_shared_embedding_column_error()

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    parent_columns = [self.categorical_column]
    return parent_columns
3282
3283
def _check_shape(shape, key):
  """Returns `shape` as a tuple after validating it; raises otherwise.

  Args:
    shape: A single dimension, or a sequence (per `nest.is_sequence`) of
      dimensions. Must not be None.
    key: Feature key; used only in error messages.

  Returns:
    The dimensions as a tuple. A non-sequence `shape` becomes a 1-tuple.

  Raises:
    AssertionError: If `shape` is None (via `assert`).
    TypeError: If a dimension is not an `int`.
    ValueError: If a dimension is less than 1.
  """
  assert shape is not None
  dims = tuple(shape) if nest.is_sequence(shape) else (shape,)
  # Validate dimensions in order; the type check precedes the range check
  # for each dimension.
  for dim in dims:
    if not isinstance(dim, int):
      raise TypeError('shape dimensions must be integer. '
                      'shape: {}, key: {}'.format(dims, key))
    if dim < 1:
      raise ValueError('shape dimensions must be greater than 0. '
                       'shape: {}, key: {}'.format(dims, key))
  return dims
3298
3299
class HashedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('HashedCategoricalColumn',
                           ('key', 'hash_bucket_size', 'dtype'))):
  """See `categorical_column_with_hash_bucket`.

  Maps each input value to an id in `[0, hash_bucket_size)` by hashing its
  string form.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    spec = parsing_ops.VarLenFeature(self.dtype)
    return {self.key: spec}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Hashes the values of `input_tensor` into bucket ids.

    Args:
      input_tensor: A `SparseTensor` of strings or integers.

    Returns:
      A `SparseTensor` with the same indices and dense shape as
      `input_tensor`, whose values are bucket ids.

    Raises:
      ValueError: If `input_tensor` is not a `SparseTensor`, or if its
        integer-ness differs from that of the column's `dtype`.
    """
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    # Values are hashed as strings, so non-string columns stringify first.
    values_to_hash = input_tensor.values
    if self.dtype != dtypes.string:
      values_to_hash = string_ops.as_string(values_to_hash)

    bucket_ids = string_ops.string_to_hash_bucket_fast(
        values_to_hash, self.hash_bucket_size, name='lookup')
    return sparse_tensor_lib.SparseTensor(
        input_tensor.indices, bucket_ids, input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Hashes the values in the feature_column."""
    raw_input = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_input = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused: no variables are created.
    ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class; `dtype` is stored by name."""
    return dict(zip(self._fields, self), dtype=self.dtype.name)

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs.update(dtype=dtypes.as_dtype(config['dtype']))
    return cls(**kwargs)
3406
3407
class VocabularyFileCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('VocabularyFileCategoricalColumn',
                           ('key', 'vocabulary_file', 'vocabulary_size',
                            'num_oov_buckets', 'dtype', 'default_value'))):
  """See `categorical_column_with_vocabulary_file`.

  Ids come from a lookup table built from `vocabulary_file`. When a state
  manager is supplied, the table is registered with it and reused by later
  calls using that state manager.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    spec = parsing_ops.VarLenFeature(self.dtype)
    return {self.key: spec}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Maps the values of `input_tensor` to ids via the vocabulary table.

    Args:
      input_tensor: A `SparseTensor` of strings or integers.
      state_manager: Optional `StateManager`. When given, the table is stored
        in / fetched from it under the resource name '<key>_lookup'.

    Returns:
      The result of `table.lookup(input_tensor)`.

    Raises:
      ValueError: If the integer-ness of `input_tensor.dtype` differs from
        that of the column's `dtype`.
    """
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if input_tensor.dtype.is_integer:
      # `index_table_from_file` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
    else:
      key_dtype = self.dtype

    table_name = '{}_lookup'.format(self.key)
    if (state_manager is not None and
        state_manager.has_resource(self, table_name)):
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, table_name)
    else:
      with ops.init_scope():
        table = lookup_ops.index_table_from_file(
            vocabulary_file=self.vocabulary_file,
            num_oov_buckets=self.num_oov_buckets,
            vocab_size=self.vocabulary_size,
            default_value=self.default_value,
            key_dtype=key_dtype,
            name=table_name)
      if state_manager is not None:
        state_manager.add_resource(self, table_name, table)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary."""
    raw_input = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input), state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_input = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    # In-vocabulary ids followed by the out-of-vocabulary buckets.
    return self.vocabulary_size + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused: no variables are created.
    ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class; `dtype` is stored by name."""
    return dict(zip(self._fields, self), dtype=self.dtype.name)

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs.update(dtype=dtypes.as_dtype(config['dtype']))
    return cls(**kwargs)
3525
3526
class VocabularyListCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyListCategoricalColumn',
        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
):
  """See `categorical_column_with_vocabulary_list`.

  Ids come from a lookup table built from the in-memory `vocabulary_list`.
  When a state manager is supplied, the table is registered with it and
  reused by later calls using that state manager.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    spec = parsing_ops.VarLenFeature(self.dtype)
    return {self.key: spec}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Maps the values of `input_tensor` to ids via the vocabulary table.

    Args:
      input_tensor: A `SparseTensor` of strings or integers.
      state_manager: Optional `StateManager`. When given, the table is stored
        in / fetched from it under the resource name '<key>_lookup'.

    Returns:
      The result of `table.lookup(input_tensor)`.

    Raises:
      ValueError: If the integer-ness of `input_tensor.dtype` differs from
        that of the column's `dtype`.
    """
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if input_tensor.dtype.is_integer:
      # `index_table_from_tensor` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
    else:
      key_dtype = self.dtype

    table_name = '{}_lookup'.format(self.key)
    if (state_manager is not None and
        state_manager.has_resource(self, table_name)):
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, table_name)
    else:
      with ops.init_scope():
        table = lookup_ops.index_table_from_tensor(
            vocabulary_list=tuple(self.vocabulary_list),
            default_value=self.default_value,
            num_oov_buckets=self.num_oov_buckets,
            dtype=key_dtype,
            name=table_name)
      if state_manager is not None:
        state_manager.add_resource(self, table_name, table)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary list."""
    raw_input = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input), state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    raw_input = inputs.get(self.key)
    return self._transform_input_tensor(
        _to_sparse_input_and_drop_ignore_values(raw_input))

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    # In-vocabulary ids followed by the out-of-vocabulary buckets.
    vocab_size = len(self.vocabulary_list)
    return vocab_size + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections, trainable  # Unused: no variables are created.
    ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class; `dtype` is stored by name."""
    return dict(zip(self._fields, self), dtype=self.dtype.name)

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs.update(dtype=dtypes.as_dtype(config['dtype']))
    return cls(**kwargs)
3644
3645
class IdentityCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('IdentityCategoricalColumn',
                           ('key', 'number_buckets', 'default_value'))):
  """See `categorical_column_with_identity`.

  Fields:
    key: Feature key of the integer id input.
    number_buckets: Number of buckets; in-range ids are `[0, number_buckets)`.
      Exposed through the `num_buckets` property.
    default_value: Id substituted for inputs outside `[0, number_buckets)`, or
      `None` to pass such inputs through unchanged.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Casts ids to int64 and, optionally, remaps out-of-range ids.

    Args:
      input_tensor: `SparseTensor` of integer ids.

    Returns:
      A `SparseTensor` with the indices and dense shape of `input_tensor` and
      int64 values. When `default_value` is not `None`, every value outside
      `[0, num_buckets)` is replaced by `default_value`.

    Raises:
      ValueError: If `input_tensor` does not have an integer dtype.
    """
    if not input_tensor.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, input_tensor.dtype))
    values = input_tensor.values
    if values.dtype != dtypes.int64:
      values = math_ops.cast(values, dtypes.int64, name='values')
    if self.default_value is not None:
      # `values` is already int64 at this point, so it is not cast again.
      num_buckets = math_ops.cast(
          self.num_buckets, dtypes.int64, name='num_buckets')
      zero = math_ops.cast(0, dtypes.int64, name='zero')
      # Assign default for out-of-range values.
      values = array_ops.where_v2(
          math_ops.logical_or(
              values < zero, values >= num_buckets, name='out_of_range'),
          array_ops.fill(
              dims=array_ops.shape(values),
              value=math_ops.cast(self.default_value, dtypes.int64),
              name='default_values'), values)

    return sparse_tensor_lib.SparseTensor(
        indices=input_tensor.indices,
        values=values,
        dense_shape=input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a SparseTensor with identity values."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.number_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    return dict(zip(self._fields, self))

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    return cls(**kwargs)
3753
3754
class WeightedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'WeightedCategoricalColumn',
        ('categorical_column', 'weight_feature_key', 'dtype'))):
  """See `weighted_categorical_column`."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_weighted_by_{}'.format(
        self.categorical_column.name, self.weight_feature_key)

  def _with_weight_spec(self, categorical_spec):
    """Returns a copy of `categorical_spec` extended with the weight spec.

    The wrapped column's dict is copied rather than mutated: if that column
    hands out the same dict object on every access, mutating it would leak
    the weight key into the wrapped column's spec and make the next access of
    this property fail with the "already exists" error below.

    Args:
      categorical_spec: Parse spec dict of `categorical_column`.

    Returns:
      A new dict holding `categorical_spec` plus a `VarLenFeature(dtype)`
      under `weight_feature_key`.

    Raises:
      ValueError: If `categorical_spec` already defines `weight_feature_key`.
    """
    if self.weight_feature_key in categorical_spec:
      raise ValueError('Parse config {} already exists for {}.'.format(
          categorical_spec[self.weight_feature_key], self.weight_feature_key))
    config = dict(categorical_spec)
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self._with_weight_spec(self.categorical_column.parse_example_spec)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self._with_weight_spec(
        self.categorical_column._parse_example_spec)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _transform_weight_tensor(self, weight_tensor):
    """Validates the raw weight feature and converts it to a `SparseTensor`.

    Args:
      weight_tensor: Raw weight feature (`Tensor` or `SparseTensor`), or
        `None` when the feature is missing.

    Returns:
      A `SparseTensor` of weights. Dense inputs are sparsified with 0.0 as the
      ignore value; non-floating weights are cast to float32.

    Raises:
      ValueError: If the weight feature is missing or its dtype differs from
        `dtype`.
    """
    if weight_tensor is None:
      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        weight_tensor)
    if self.dtype != weight_tensor.dtype.base_dtype:
      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
          self.dtype, weight_tensor.dtype))
    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
      # The weight tensor can be a regular Tensor. In this case, sparsify it.
      weight_tensor = _to_sparse_input_and_drop_ignore_values(
          weight_tensor, ignore_value=0.0)
    if not weight_tensor.dtype.is_floating:
      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
    return weight_tensor

  def transform_feature(self, transformation_cache, state_manager):
    """Applies weights to tensor generated from `categorical_column`'."""
    weight_tensor = transformation_cache.get(self.weight_feature_key,
                                             state_manager)
    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.categorical_column, state_manager))
    return (sparse_categorical_tensor, sparse_weight_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Applies weights to tensor generated from `categorical_column`'."""
    weight_tensor = inputs.get(self.weight_feature_key)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (inputs.get(self.categorical_column), weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    tensors = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    tensors = inputs.get(self)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column, self.weight_feature_key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
3877
3878
class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`.

  `keys` may mix raw string feature keys, categorical columns and nested
  `CrossedColumn`s; nested crosses are flattened to their leaf keys (see
  `_collect_leaf_level_keys`) before the hashed cross is built.
  """

  @property
  def _is_v2_column(self):
    # Raw string keys are always acceptable; every other leaf must be a v2
    # `FeatureColumn`.
    return all(
        isinstance(leaf, six.string_types) or
        (isinstance(leaf, FeatureColumn) and leaf._is_v2_column)  # pylint: disable=protected-access
        for leaf in _collect_leaf_level_keys(self))

  @property
  def name(self):
    """See `FeatureColumn` base class.

    The sorted leaf names (column names or raw string keys) joined by '_X_'.
    """
    column_types = (FeatureColumn, fc_old._FeatureColumn)  # pylint: disable=protected-access
    leaf_names = [
        leaf.name if isinstance(leaf, column_types) else leaf
        for leaf in _collect_leaf_level_keys(self)
    ]
    return '_X_'.join(sorted(leaf_names))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class.

    Merges the specs of all direct keys; raw string keys are parsed as string
    `VarLenFeature`s.
    """
    merged_spec = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        key_spec = key.parse_example_spec
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        key_spec = key._parse_example_spec  # pylint: disable=protected-access
      else:  # A raw string feature key.
        key_spec = {key: parsing_ops.VarLenFeature(dtypes.string)}
      merged_spec.update(key_spec)
    return merged_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  @staticmethod
  def _unweighted_ids(column, ids_and_weights):
    """Returns the id tensor of `column`, rejecting weighted columns."""
    if ids_and_weights.weight_tensor is not None:
      raise ValueError(
          'crossed_column does not support weight_tensor, but the given '
          'column populates weight_tensor. '
          'Given column: {}'.format(column.name))
    return ids_and_weights.id_tensor

  def _hashed_cross(self, feature_tensors):
    """Builds the hashed sparse cross of `feature_tensors`."""
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    crossed_inputs = []
    for leaf in _collect_leaf_level_keys(self):
      if isinstance(leaf, six.string_types):
        tensor = transformation_cache.get(leaf, state_manager)
      elif isinstance(leaf, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
        tensor = self._unweighted_ids(
            leaf, leaf.get_sparse_tensors(transformation_cache, state_manager))
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(leaf))
      crossed_inputs.append(tensor)
    return self._hashed_cross(crossed_inputs)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    crossed_inputs = []
    for leaf in _collect_leaf_level_keys(self):
      if isinstance(leaf, six.string_types):
        tensor = inputs.get(leaf)
      elif isinstance(leaf, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        tensor = self._unweighted_ids(
            leaf, leaf._get_sparse_tensors(inputs))  # pylint: disable=protected-access
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(leaf))
      crossed_inputs.append(tensor)
    return self._hashed_cross(crossed_inputs)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature (`hash_bucket_size`)."""
    bucket_count = self.hash_bucket_size
    return bucket_count

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class; crosses carry no weights."""
    crossed_ids = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(crossed_ids, None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections, trainable  # Unused; no variables are created.
    crossed_ids = inputs.get(self)
    return CategoricalColumn.IdWeightPair(crossed_ids, None)

  @property
  def parents(self):
    """See `FeatureColumn` base class; the direct (unflattened) keys."""
    return list(self.keys)

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = {field: value for field, value in zip(self._fields, self)}
    config['keys'] = tuple(serialize_feature_column(key) for key in self.keys)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    constructor_args = _standardize_and_copy_config(config)
    constructor_args['keys'] = tuple(
        deserialize_feature_column(serialized, custom_objects, columns_by_name)
        for serialized in config['keys'])
    return cls(**constructor_args)
4020
4021
def _collect_leaf_level_keys(cross):
  """Flattens a (possibly nested) cross into its leaf keys.

  Nested `CrossedColumn`s are expanded depth-first in place, so leaves appear
  in the same left-to-right order as in the nested `keys` tuples.

  Args:
    cross: A `CrossedColumn`.

  Returns:
    A list of strings or `CategoricalColumn` instances.
  """

  def _walk(node):
    for key in node.keys:
      if isinstance(key, CrossedColumn):
        for leaf in _walk(key):
          yield leaf
      else:
        yield key

  return list(_walk(cross))
4038
4039
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights.

  Args:
    sparse_ids: `SparseTensor` of ids; entries with a negative id are dropped.
    sparse_weights: Optional `SparseTensor` of weights. When given, the same
      retain mask is applied to it so ids and weights stay aligned.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with the invalid entries removed;
    `sparse_weights` is returned as `None` if it was passed as `None`.
  """
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    # AND-ing with an all-True tensor shaped like the weight values leaves the
    # mask's truth values unchanged; presumably it ties the mask to the
    # weights' shape so an id/weight size mismatch surfaces here.
    # NOTE(review): confirm intent.
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights
4051
4052
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights.

  An entry is kept only if its weight is strictly greater than 0; the same
  mask is applied to both inputs so they stay aligned.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights with one value per id value, or
      `None`, in which case both inputs are returned unchanged.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with non-positive-weight entries
    removed.
  """
  if sparse_weights is None:
    return sparse_ids, sparse_weights
  is_weights_valid = math_ops.greater(sparse_weights.values, 0)
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
  sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
4060
4061
class IndicatorColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple('IndicatorColumn', ('categorical_column'))):
  """Represents a one-hot column for use in deep networks.

  Each example becomes a dense float vector whose width is the wrapped
  column's bucket count: a multi-hot count of its ids, or, when the wrapped
  column is weighted, the ids' weights scattered into their buckets.

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` function.
  """

  @property
  def _is_v2_column(self):
    wrapped = self.categorical_column
    if not isinstance(wrapped, FeatureColumn):
      return False
    return wrapped._is_v2_column  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_indicator'.format(self.categorical_column.name)

  def _transform_id_weight_pair(self, id_weight_pair, size):
    """Densifies an `IdWeightPair` into a tensor of width `size`.

    Args:
      id_weight_pair: `IdWeightPair` produced by `categorical_column`.
      size: Number of buckets, i.e. the width of the last output dimension.

    Returns:
      Without weights: the sum over axis -2 of float one-hot encodings of the
      ids (a multi-hot count). With weights: a dense tensor of the weights,
      with duplicate indices merged by `scatter_nd`.
    """
    ids = id_weight_pair.id_tensor
    weights = id_weight_pair.weight_tensor

    if weights is None:
      # Empty positions are filled with -1, an id outside [0, size).
      dense_ids = sparse_ops.sparse_tensor_to_dense(ids, default_value=-1)
      # One hot must be float for tf.concat reasons since all other inputs to
      # input_layer are float32.
      one_hot_ids = array_ops.one_hot(
          dense_ids, depth=size, on_value=1.0, off_value=0.0)
      # Reduce to get a multi-hot per example.
      return math_ops.reduce_sum(one_hot_ids, axis=[-2])

    merged = sparse_ops.sparse_merge(
        sp_ids=ids, sp_values=weights, vocab_size=int(size))
    # Remove (?, -1) index.
    merged = sparse_ops.sparse_slice(merged, [0, 0], merged.dense_shape)
    # Use scatter_nd to merge duplicated indices if existed,
    # instead of sparse_tensor_to_dense.
    return array_ops.scatter_nd(merged.indices, merged.values,
                                merged.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns the dense `Tensor` representing this feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor` whose last dimension equals
      `variable_shape[-1]`.
    """
    pair = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    depth = self.variable_shape[-1]
    return self._transform_id_weight_pair(pair, depth)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    depth = self._variable_shape[-1]
    return self._transform_id_weight_pair(pair, depth)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class; same spec as the wrapped column."""
    wrapped = self.categorical_column
    return wrapped.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    wrapped = self.categorical_column
    return wrapped._parse_example_spec  # pylint: disable=protected-access

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
    wrapped = self.categorical_column
    if isinstance(wrapped, FeatureColumn):
      bucket_count = wrapped.num_buckets
    else:
      bucket_count = wrapped._num_buckets  # pylint: disable=protected-access
    return tensor_shape.TensorShape([1, bucket_count])

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    bucket_count = self.categorical_column._num_buckets  # pylint: disable=protected-access
    return tensor_shape.TensorShape([1, bucket_count])

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the dense `Tensor` built by `transform_feature`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    wrapped = self.categorical_column
    if isinstance(wrapped, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(wrapped), wrapped))
    # The feature was already transformed; fetch that cached result.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections, trainable  # Unused; no variables are created.
    wrapped = self.categorical_column
    sequence_types = (SequenceCategoricalColumn,
                      fc_old._SequenceCategoricalColumn)  # pylint: disable=protected-access
    if isinstance(wrapped, sequence_types):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(wrapped), wrapped))
    # The feature was already transformed; fetch that cached result.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    wrapped = self.categorical_column
    if not isinstance(wrapped, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(wrapped), wrapped))
    # The feature was already transformed; fetch that cached result.
    dense = transformation_cache.get(self, state_manager)
    pair = wrapped.get_sparse_tensors(transformation_cache, state_manager)
    lengths = fc_utils.sequence_length_from_sparse_tensor(pair.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense, sequence_length=lengths)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # No variables are created here, so these arguments are irrelevant.
    del weight_collections, trainable
    wrapped = self.categorical_column
    sequence_types = (SequenceCategoricalColumn,
                      fc_old._SequenceCategoricalColumn)  # pylint: disable=protected-access
    if not isinstance(wrapped, sequence_types):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(wrapped), wrapped))
    # The feature was already transformed by _transform_feature; fetch it.
    dense = inputs.get(self)
    pair = wrapped._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    lengths = fc_utils.sequence_length_from_sparse_tensor(pair.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense, sequence_length=lengths)

  @property
  def parents(self):
    """See `FeatureColumn` base class; the wrapped categorical column."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = {field: value for field, value in zip(self._fields, self)}
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    constructor_args = _standardize_and_copy_config(config)
    constructor_args['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**constructor_args)
4286
4287
def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Only tensors whose first dimension is statically known take part; the first
  such tensor fixes the expected batch size that the others must match.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns, indexable in the same order as
      `tensors`; only used to build the error message.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # Batch sizes are wrapped in `Dimension` objects so `is_compatible_with`
  # can be used for the comparison.
  expected_batch_size = None
  batch_size_column_index = None
  for i, tensor in enumerate(tensors):
    batch_size = tensor_shape.Dimension(
        tensor_shape.dimension_value(tensor.shape[0]))
    if batch_size.value is None:
      continue  # Statically unknown; nothing to compare.
    if expected_batch_size is None:
      batch_size_column_index = i
      expected_batch_size = batch_size
    elif not expected_batch_size.is_compatible_with(batch_size):
      raise ValueError(
          'Batch size (first dimension) of each feature must be same. '
          'Batch size of columns ({}, {}): ({}, {})'.format(
              columns[batch_size_column_index].name, columns[i].name,
              expected_batch_size, batch_size))
4313
4314
4315class SequenceCategoricalColumn(
4316    CategoricalColumn,
4317    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
4318    collections.namedtuple('SequenceCategoricalColumn',
4319                           ('categorical_column'))):
4320  """Represents sequences of categorical data."""
4321
4322  @property
4323  def _is_v2_column(self):
4324    return (isinstance(self.categorical_column, FeatureColumn) and
4325            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
4326
4327  @property
4328  def name(self):
4329    """See `FeatureColumn` base class."""
4330    return self.categorical_column.name
4331
4332  @property
4333  def parse_example_spec(self):
4334    """See `FeatureColumn` base class."""
4335    return self.categorical_column.parse_example_spec
4336
4337  @property
4338  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4339                          _FEATURE_COLUMN_DEPRECATION)
4340  def _parse_example_spec(self):
4341    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
4342
4343  def transform_feature(self, transformation_cache, state_manager):
4344    """See `FeatureColumn` base class."""
4345    return self.categorical_column.transform_feature(transformation_cache,
4346                                                     state_manager)
4347
4348  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4349                          _FEATURE_COLUMN_DEPRECATION)
4350  def _transform_feature(self, inputs):
4351    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access
4352
4353  @property
4354  def num_buckets(self):
4355    """Returns number of buckets in this sparse feature."""
4356    return self.categorical_column.num_buckets
4357
4358  @property
4359  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4360                          _FEATURE_COLUMN_DEPRECATION)
4361  def _num_buckets(self):
4362    return self.categorical_column._num_buckets  # pylint: disable=protected-access
4363
4364  def _get_sparse_tensors_helper(self, sparse_tensors):
4365    id_tensor = sparse_tensors.id_tensor
4366    weight_tensor = sparse_tensors.weight_tensor
4367    # Expands third dimension, if necessary so that embeddings are not
4368    # combined during embedding lookup. If the tensor is already 3D, leave
4369    # as-is.
4370    shape = array_ops.shape(id_tensor)
4371    # Compute the third dimension explicitly instead of setting it to -1, as
4372    # that doesn't work for dynamically shaped tensors with 0-length at runtime.
4373    # This happens for empty sequences.
4374    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
4375    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
4376    if weight_tensor is not None:
4377      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
4378    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
4379
4380  def get_sparse_tensors(self, transformation_cache, state_manager):
4381    """Returns an IdWeightPair.
4382
4383    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
4384    weights.
4385
4386    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
4387    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
4388    `SparseTensor` of `float` or `None` to indicate all weights should be
4389    taken to be 1. If specified, `weight_tensor` must have exactly the same
4390    shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
4391    output of a `VarLenFeature` which is a ragged matrix.
4392
4393    Args:
4394      transformation_cache: A `FeatureTransformationCache` object to access
4395        features.
4396      state_manager: A `StateManager` to create / access resources such as
4397        lookup tables.
4398    """
4399    sparse_tensors = self.categorical_column.get_sparse_tensors(
4400        transformation_cache, state_manager)
4401    return self._get_sparse_tensors_helper(sparse_tensors)
4402
4403  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4404                          _FEATURE_COLUMN_DEPRECATION)
4405  def _get_sparse_tensors(self, inputs, weight_collections=None,
4406                          trainable=None):
4407    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
4408    return self._get_sparse_tensors_helper(sparse_tensors)
4409
4410  @property
4411  def parents(self):
4412    """See 'FeatureColumn` base class."""
4413    return [self.categorical_column]
4414
4415  def get_config(self):
4416    """See 'FeatureColumn` base class."""
4417    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
4418    config = dict(zip(self._fields, self))
4419    config['categorical_column'] = serialize_feature_column(
4420        self.categorical_column)
4421    return config
4422
4423  @classmethod
4424  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4425    """See 'FeatureColumn` base class."""
4426    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
4427    _check_config_keys(config, cls._fields)
4428    kwargs = _standardize_and_copy_config(config)
4429    kwargs['categorical_column'] = deserialize_feature_column(
4430        config['categorical_column'], custom_objects, columns_by_name)
4431    return cls(**kwargs)
4432
4433
def _check_config_keys(config, expected_keys):
  """Raises unless `config` has exactly the keys in `expected_keys`.

  Key order and duplicates in `expected_keys` are ignored; only the sets of
  keys are compared.

  Args:
    config: dict holding a serialized feature column configuration.
    expected_keys: iterable of the key names the config must contain.

  Raises:
    ValueError: if the config's keys differ from `expected_keys`.
  """
  actual_keys = set(config.keys())
  wanted_keys = set(expected_keys)
  if actual_keys == wanted_keys:
    return
  raise ValueError('Invalid config: {}, expected keys: {}'.format(
      config, expected_keys))
4439
4440
def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned to tuples.

  Keras serialization uses nest to listify everything.
  This causes problems with the NumericColumn shape, which becomes
  unhashable. We could try to solve this on the Keras side, but that
  would require lots of tracking to avoid changing existing behavior.
  Instead, we ensure here that we revive correctly.

  Lists are converted recursively. A nested value such as a list of lists
  therefore comes back as nested tuples, not as a tuple that still contains
  unhashable lists. Values of any other type, dicts included, are shared with
  the input and left untouched. The input config is not modified.

  Args:
    config: dict that will be used to revive a Feature Column.

  Returns:
    Shallow copy of config with (possibly nested) lists turned to tuples.
  """

  def _to_tuple(value):
    # Only lists are rebuilt; every other value is returned as-is.
    if isinstance(value, list):
      return tuple(_to_tuple(item) for item in value)
    return value

  kwargs = config.copy()
  # Iterate over a snapshot of the keys rather than mutating the dict while
  # iterating over its live items view.
  for key in list(kwargs):
    kwargs[key] = _to_tuple(kwargs[key])
  return kwargs
4462
4463
def _sanitize_column_name_for_variable_scope(name):
  """Makes a user-provided feature name safe to use as a variable scope.

  Every character other than an ASCII letter, an ASCII digit, `_`, `.` or `-`
  is replaced by an underscore.

  Args:
    name: the feature name to sanitize.

  Returns:
    `name` with each disallowed character replaced by `'_'`.
  """
  return re.sub(r'[^A-Za-z0-9_.\-]', '_', name)
4468