1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Defines the base Keras layer shared by feature-column layers."""
16
17# This file was originally under tf/python/feature_column, and was moved to
18# Keras package in order to remove the reverse dependency from TF to Keras.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24
25from tensorflow.python.feature_column import feature_column_v2
26from tensorflow.python.keras.engine.base_layer import Layer
27from tensorflow.python.keras.utils import generic_utils
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import variable_scope
30
31
class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers. Subclasses are expected to implement
  `_target_shape`, which both `compute_output_shape` and
  `_process_dense_tensor` use to determine the output shape.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable: Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the layer.
    partitioner: Optional partitioner applied to the variable scope in which
      each column's state is created (see `build`). It is serialized by
      `get_config`.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = feature_column_v2._normalize_feature_columns(  # pylint: disable=protected-access
        feature_columns)
    # `self.trainable` is read after the base Layer constructor has set it.
    self._state_manager = feature_column_v2._StateManagerImpl(  # pylint: disable=protected-access
        self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    """Creates the state of every column.

    Each column's state is created under the variable scope
    `<layer name>/<sanitized column name>`, with the layer's partitioner
    applied to the outer scope.
    """
    for column in self._feature_columns:
      # The outer scope is deliberately re-entered for each column (rather
      # than hoisted out of the loop) to keep the existing scoping behavior.
      with variable_scope.variable_scope(
          self.name, partitioner=self._partitioner):
        with variable_scope.variable_scope(
            feature_column_v2._sanitize_column_name_for_variable_scope(  # pylint: disable=protected-access
                column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    Subclasses must override this. For backward compatibility, a subclass
    that only overrides `_output_shape` is still honored.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.

    Raises:
      NotImplementedError: if neither `_target_shape` nor `_output_shape` is
        overridden.
    """
    return self._output_shape(input_shape, num_elements)

  def _output_shape(self, input_shape, num_elements):
    """Deprecated name for `_target_shape`; override `_target_shape` instead.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.
    """
    raise NotImplementedError('Calling an abstract method.')

  def compute_output_shape(self, input_shape):
    """Returns the layer's output shape.

    The last dimension is the sum of `variable_shape.num_elements()` over all
    feature columns; the rest of the shape is decided by `_target_shape`.
    """
    total_elements = sum(
        column.variable_shape.num_elements()
        for column in self._feature_columns)
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns.

    Args:
      output_tensors: Dense tensors, one per feature column, in the same order
        as the layer's feature columns.

    Returns:
      The tensors concatenated along the last axis.
    """
    feature_column_v2._verify_static_batch_size_equality(  # pylint: disable=protected-access
        output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    """Returns the layer config, including serialized columns and partitioner."""
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)

    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    # Entries from `config` take precedence over the base Layer config.
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates a layer from a config produced by `get_config`.

    A missing 'partitioner' entry is treated as no partitioner, so configs
    written without that key can still be loaded.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config.get('partitioner'), custom_objects)

    return cls(**config_cp)
145