# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Embedding')
class Embedding(Layer):
  """Turns positive integers (indexes) into dense vectors of fixed size.

  e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`

  This layer can only be used as the first layer in a model.

  Example:

  ```python
  model = Sequential()
  model.add(Embedding(1000, 64, input_length=10))
  # The model will take as input an integer matrix of size
  # (batch, input_length), and the largest integer (i.e. word index) in the
  # input should be no larger than 999 (vocabulary size).
  # Now model.output_shape == (None, 10, 64), where None is the batch
  # dimension.

  input_array = np.random.randint(1000, size=(32, 10))

  model.compile('rmsprop', 'mse')
  output_array = model.predict(input_array)
  assert output_array.shape == (32, 10, 64)
  ```

  Arguments:
    input_dim: int > 0. Size of the vocabulary,
      i.e. maximum integer index + 1.
    output_dim: int >= 0. Dimension of the dense embedding.
    embeddings_initializer: Initializer for the `embeddings` matrix.
    embeddings_regularizer: Regularizer function applied to
      the `embeddings` matrix.
    embeddings_constraint: Constraint function applied to
      the `embeddings` matrix.
    mask_zero: Whether or not the input value 0 is a special "padding"
      value that should be masked out.
      This is useful when using recurrent layers
      which may take variable-length input.
      If this is `True`, then all subsequent layers
      in the model need to support masking or an exception will be raised.
      If mask_zero is set to True, as a consequence, index 0 cannot be
      used in the vocabulary (input_dim should equal size of
      vocabulary + 1).
    input_length: Length of input sequences, when it is constant.
      This argument is required if you are going to connect
      `Flatten` then `Dense` layers downstream
      (without it, the shape of the dense outputs cannot be computed).

  Input shape:
    2D tensor with shape: `(batch_size, input_length)`.

  Output shape:
    3D tensor with shape: `(batch_size, input_length, output_dim)`.
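
  When `mask_zero=True`, index 0 is reserved for padding and downstream
  layers that support masking will skip the padded timesteps. A minimal
  sketch (the `LSTM` layer here is just one illustrative consumer of the
  mask):

  ```python
  # 999 real tokens plus the reserved padding index 0 => input_dim = 1000.
  model = Sequential()
  model.add(Embedding(1000, 64, mask_zero=True))
  model.add(LSTM(32))  # Recurrent layers support masking.
  ```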
89 """ 90 91 def __init__(self, 92 input_dim, 93 output_dim, 94 embeddings_initializer='uniform', 95 embeddings_regularizer=None, 96 activity_regularizer=None, 97 embeddings_constraint=None, 98 mask_zero=False, 99 input_length=None, 100 **kwargs): 101 if 'input_shape' not in kwargs: 102 if input_length: 103 kwargs['input_shape'] = (input_length,) 104 else: 105 kwargs['input_shape'] = (None,) 106 dtype = kwargs.pop('dtype', K.floatx()) 107 super(Embedding, self).__init__(dtype=dtype, **kwargs) 108 109 self.input_dim = input_dim 110 self.output_dim = output_dim 111 self.embeddings_initializer = initializers.get(embeddings_initializer) 112 self.embeddings_regularizer = regularizers.get(embeddings_regularizer) 113 self.activity_regularizer = regularizers.get(activity_regularizer) 114 self.embeddings_constraint = constraints.get(embeddings_constraint) 115 self.mask_zero = mask_zero 116 self.supports_masking = mask_zero 117 self.input_length = input_length 118 119 @tf_utils.shape_type_conversion 120 def build(self, input_shape): 121 # Note: most sparse optimizers do not have GPU kernels defined. When 122 # building graphs, the placement algorithm is able to place variables on CPU 123 # since it knows all kernels using the variable only exist on CPU. 124 # When eager execution is enabled, the placement decision has to be made 125 # right now. Checking for the presence of GPUs to avoid complicating the 126 # TPU codepaths which can handle sparse optimizers. 127 if context.executing_eagerly() and context.context().num_gpus(): 128 with ops.device('cpu:0'): 129 self.embeddings = self.add_weight( 130 shape=(self.input_dim, self.output_dim), 131 initializer=self.embeddings_initializer, 132 name='embeddings', 133 regularizer=self.embeddings_regularizer, 134 constraint=self.embeddings_constraint) 135 else: 136 self.embeddings = self.add_weight( 137 shape=(self.input_dim, self.output_dim), 138 initializer=self.embeddings_initializer, 139 name='embeddings', 140 regularizer=self.embeddings_regularizer, 141 constraint=self.embeddings_constraint) 142 self.built = True 143 144 def compute_mask(self, inputs, mask=None): 145 if not self.mask_zero: 146 return None 147 148 return math_ops.not_equal(inputs, 0) 149 150 @tf_utils.shape_type_conversion 151 def compute_output_shape(self, input_shape): 152 if self.input_length is None: 153 return input_shape + (self.output_dim,) 154 else: 155 # input_length can be tuple if input is 3D or higher 156 if isinstance(self.input_length, (list, tuple)): 157 in_lens = list(self.input_length) 158 else: 159 in_lens = [self.input_length] 160 if len(in_lens) != len(input_shape) - 1: 161 raise ValueError('"input_length" is %s, ' 162 'but received input has shape %s' % (str( 163 self.input_length), str(input_shape))) 164 else: 165 for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])): 166 if s1 is not None and s2 is not None and s1 != s2: 167 raise ValueError('"input_length" is %s, ' 168 'but received input has shape %s' % (str( 169 self.input_length), str(input_shape))) 170 elif s1 is None: 171 in_lens[i] = s2 172 return (input_shape[0],) + tuple(in_lens) + (self.output_dim,) 173 174 def call(self, inputs): 175 dtype = K.dtype(inputs) 176 if dtype != 'int32' and dtype != 'int64': 177 inputs = math_ops.cast(inputs, 'int32') 178 out = embedding_ops.embedding_lookup(self.embeddings, inputs) 179 return out 180 181 def get_config(self): 182 config = { 183 'input_dim': self.input_dim, 184 'output_dim': self.output_dim, 185 'embeddings_initializer': 186 
            initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer':
            regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint':
            constraints.serialize(self.embeddings_constraint),
        'mask_zero': self.mask_zero,
        'input_length': self.input_length
    }
    base_config = super(Embedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))