# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn for sequential input.

NOTE: This API is a work in progress and will likely be changing frequently.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.util.tf_export import keras_export

# pylint: disable=protected-access


@keras_export('keras.experimental.SequenceFeatures')
class SequenceFeatures(kfc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same
  `sequence_length`. The output of this layer can be fed into sequence
  networks, such as an RNN.

  The output of this layer is a 3D `Tensor` of shape `[batch_size, T, D]`.
  `T` is the maximum sequence length for this batch, which could differ from
  batch to batch.

  If multiple `feature_columns` are given, each with `Di` `num_elements`,
  their outputs are concatenated, so the final `Tensor` has shape
  `[batch_size, T, D0 + D1 + ... + Dn]`.

  Example:

  ```python
  import tensorflow as tf

  # Behavior of some cells or feature columns may depend on whether we are in
  # training or inference mode, e.g. applying dropout.
  training = True
  rating = tf.feature_column.sequence_numeric_column('rating')
  watches = tf.feature_column.sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = tf.feature_column.embedding_column(watches,
                                                         dimension=10)
  columns = [rating, watches_embedding]

  features = {
      'rating': tf.sparse.from_dense([[1.0, 1.1, 0, 0, 0],
                                      [2.0, 2.1, 2.2, 2.3, 2.5]]),
      'watches': tf.sparse.from_dense([[2, 85, 0, 0, 0], [33, 78, 2, 73, 1]])
  }

  sequence_input_layer = tf.keras.experimental.SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_input_layer(
      features, training=training)
  sequence_length_mask = tf.sequence_mask(sequence_length)
  hidden_size = 32
  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```
  """

  def __init__(
      self,
      feature_columns,
      trainable=True,
      name=None,
      **kwargs):
    """Constructs a SequenceFeatures layer.

    Args:
      feature_columns: An iterable of dense sequence columns. Valid columns are
        - `embedding_column` that wraps a `sequence_categorical_column_with_*`
        - `sequence_numeric_column`.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the SequenceFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: If any of the `feature_columns` is not a
        `SequenceDenseColumn`.
    """
    super(SequenceFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=fc.SequenceDenseColumn,
        **kwargs)

  @property
  def _is_feature_layer(self):
    return True

  def _target_shape(self, input_shape, total_elements):
    return (input_shape[0], input_shape[1], total_elements)

  def call(self, features, training=None):
    """Returns sequence input corresponding to the `feature_columns`.

    Args:
      features: A dict mapping keys to tensors.
      training: Python boolean or None, indicating whether the layer is being
        run in training mode. This argument is passed to the call method of
        any `FeatureColumn` that takes a `training` argument. For example, if
        a `FeatureColumn` performed dropout, the column could expose a
        `training` argument to control whether the dropout should be applied.
        If `None`, defaults to `tf.keras.backend.learning_phase()`.

    Returns:
      An `(input_layer, sequence_length)` tuple where:
      - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
        `T` is the maximum sequence length for this batch, which could differ
        from batch to batch. `D` is the sum of `num_elements` for all
        `feature_columns`.
      - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
        length for each example.

    Raises:
      ValueError: If `features` is not a dictionary.
    """
    if not isinstance(features, dict):
      raise ValueError('We expected a dictionary here. Instead we got: ',
                       features)
    if training is None:
      training = backend.learning_phase()
    transformation_cache = fc.FeatureTransformationCache(features)
    output_tensors = []
    sequence_lengths = []

    for column in self._feature_columns:
      with backend.name_scope(column.name):
        try:
          # Newer column implementations accept a `training` argument so they
          # can behave differently in training and inference mode.
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager, training=training)
        except TypeError:
          # Fall back for columns whose `get_sequence_dense_tensor` does not
          # accept a `training` argument.
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager)
        # Flattens the final dimension to produce a 3D Tensor.
        output_tensors.append(self._process_dense_tensor(column, dense_tensor))
        sequence_lengths.append(sequence_length)

    # Check and process sequence lengths.
    fc._verify_static_batch_size_equality(sequence_lengths,
                                          self._feature_columns)
    sequence_length = _assert_all_equal_and_return(sequence_lengths)

    return self._verify_and_concat_tensors(output_tensors), sequence_length


def _assert_all_equal_and_return(tensors, name=None):
  """Asserts that all tensors are equal and returns the first one."""
  with backend.name_scope(name or 'assert_all_equal'):
    if len(tensors) == 1:
      return tensors[0]
    assert_equal_ops = []
    for t in tensors[1:]:
      assert_equal_ops.append(check_ops.assert_equal(tensors[0], t))
    with ops.control_dependencies(assert_equal_ops):
      return array_ops.identity(tensors[0])
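

# A minimal usage sketch (not part of the library API), kept under a
# `__main__` guard so it only runs when this file is executed directly. It
# assumes TensorFlow 2.x with `tf.feature_column.sequence_numeric_column` and
# `tf.keras.experimental.SequenceFeatures` available, and mirrors the example
# in the class docstring above.
if __name__ == '__main__':
  import tensorflow as tf

  rating = tf.feature_column.sequence_numeric_column('rating')
  features = {
      # Zero entries are treated as missing by `tf.sparse.from_dense`, so the
      # two examples have sequence lengths 2 and 5 respectively.
      'rating': tf.sparse.from_dense([[1.0, 1.1, 0.0, 0.0, 0.0],
                                      [2.0, 2.1, 2.2, 2.3, 2.5]]),
  }

  sequence_input_layer = tf.keras.experimental.SequenceFeatures([rating])
  # `sequence_input` has shape [batch_size, T, D]; `sequence_length` has
  # shape [batch_size] and holds the unpadded length of each example.
  sequence_input, sequence_length = sequence_input_layer(features)
  mask = tf.sequence_mask(sequence_length)
  outputs = tf.keras.layers.LSTM(8)(sequence_input, mask=mask)
  print(outputs.shape)  # (2, 8)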