# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Embedding')
class Embedding(Layer):
  """Turns positive integers (indexes) into dense vectors of fixed size.

  e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`

  This layer can only be used as the first layer in a model.

  Example:

  ```python
  model = Sequential()
  model.add(Embedding(1000, 64, input_length=10))
  # The model will take as input an integer matrix of size (batch,
  # input_length), and the largest integer (i.e. word index) in the
  # input should be no larger than 999 (vocabulary size).
  # Now model.output_shape == (None, 10, 64), where None is the batch
  # dimension.

  input_array = np.random.randint(1000, size=(32, 10))

  model.compile('rmsprop', 'mse')
  output_array = model.predict(input_array)
  assert output_array.shape == (32, 10, 64)
  ```

  Arguments:
    input_dim: int > 0. Size of the vocabulary,
      i.e. maximum integer index + 1.
    output_dim: int >= 0. Dimension of the dense embedding.
    embeddings_initializer: Initializer for the `embeddings` matrix.
    embeddings_regularizer: Regularizer function applied to
      the `embeddings` matrix.
    embeddings_constraint: Constraint function applied to
      the `embeddings` matrix.
    mask_zero: Whether or not the input value 0 is a special "padding"
      value that should be masked out.
      This is useful when using recurrent layers,
      which may take variable-length input.
      If this is `True`, then all subsequent layers in the model need to
      support masking or an exception will be raised.
      As a consequence, if `mask_zero` is `True`, index 0 cannot be used
      in the vocabulary (`input_dim` should equal the size of the
      vocabulary + 1). See the masking example below.
    input_length: Length of input sequences, when it is constant.
      This argument is required if you are going to connect
      `Flatten` then `Dense` layers downstream
      (without it, the shape of the dense outputs cannot be computed).

  Input shape:
    2D tensor with shape: `(batch_size, input_length)`.

  Output shape:
    3D tensor with shape: `(batch_size, input_length, output_dim)`.
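
  Masking example (a minimal sketch; the mask values shown are
  illustrative and assume a zero-padded batch of ids):

  ```python
  layer = Embedding(1000, 64, mask_zero=True)
  input_ids = np.array([[7, 2, 0, 0], [3, 9, 4, 0]])  # 0 marks padding
  mask = layer.compute_mask(input_ids)
  # mask: [[True, True, False, False], [True, True, True, False]]
  ```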
  """

  def __init__(self,
               input_dim,
               output_dim,
               embeddings_initializer='uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None,
               mask_zero=False,
               input_length=None,
               **kwargs):
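    # If the user did not pass an input_shape, infer one from input_length
    # so this layer can be used as the first layer of a model.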
    if 'input_shape' not in kwargs:
      if input_length:
        kwargs['input_shape'] = (input_length,)
      else:
        kwargs['input_shape'] = (None,)
    dtype = kwargs.pop('dtype', K.floatx())
    super(Embedding, self).__init__(dtype=dtype, **kwargs)

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)
    self.mask_zero = mask_zero
    self.supports_masking = mask_zero
    self.input_length = input_length

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Note: most sparse optimizers do not have GPU kernels defined. When
    # building graphs, the placement algorithm is able to place variables on
    # CPU since it knows all kernels using the variable only exist on CPU.
    # When eager execution is enabled, the placement decision has to be made
    # right now. We check for the presence of GPUs to avoid complicating the
    # TPU codepaths, which can handle sparse optimizers.
    if context.executing_eagerly() and context.context().num_gpus():
      with ops.device('cpu:0'):
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer,
            name='embeddings',
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint)
    else:
      self.embeddings = self.add_weight(
          shape=(self.input_dim, self.output_dim),
          initializer=self.embeddings_initializer,
          name='embeddings',
          regularizer=self.embeddings_regularizer,
          constraint=self.embeddings_constraint)
    self.built = True

  def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
      return None

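    # The mask is True for real tokens (non-zero ids) and False at padding
    # positions.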
    return math_ops.not_equal(inputs, 0)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
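    # Without a fixed input_length, the output shape is just the input
    # shape with the embedding dimension appended. Otherwise, validate the
    # configured length(s) against the received shape and fill in any None
    # entries.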
    if self.input_length is None:
      return input_shape + (self.output_dim,)
    else:
      # input_length can be a tuple if the input is 3D or higher.
      if isinstance(self.input_length, (list, tuple)):
        in_lens = list(self.input_length)
      else:
        in_lens = [self.input_length]
      if len(in_lens) != len(input_shape) - 1:
        raise ValueError('"input_length" is %s, but received input has '
                         'shape %s' % (str(self.input_length),
                                       str(input_shape)))
      else:
        for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
          if s1 is not None and s2 is not None and s1 != s2:
            raise ValueError('"input_length" is %s, but received input has '
                             'shape %s' % (str(self.input_length),
                                           str(input_shape)))
          elif s1 is None:
            in_lens[i] = s2
      return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

  def call(self, inputs):
    dtype = K.dtype(inputs)
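    # embedding_lookup requires integer ids, so cast any other dtype to
    # int32 before the lookup.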
    if dtype not in ('int32', 'int64'):
      inputs = math_ops.cast(inputs, 'int32')
    return embedding_ops.embedding_lookup(self.embeddings, inputs)

  def get_config(self):
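    # Serialize the constructor arguments so an identical layer can be
    # rebuilt via Embedding.from_config(config).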
    config = {
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'embeddings_initializer':
            initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer':
            regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint':
            constraints.serialize(self.embeddings_constraint),
        'mask_zero': self.mask_zero,
        'input_length': self.input_length
    }
    base_config = super(Embedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))