# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import string

import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import einsum_dense
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape
  (bs, <non-attention dims>, <attention dims>, num_heads, channels).
  `bs` and <non-attention dims> are treated as <batch dims>.
  The attention operations can be generalized:
  (1) Query-key dot product:
  (<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>,
  num_heads, <query attention dims>, <key attention dims>)
  (2) Combination:
  (<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
  <query attention dims>, num_heads, channels)

  Args:
    rank: the rank of the query, key, value tensors.
    attn_axes: a list/tuple of axes, in [-1, rank), over which attention is
      applied.

  Returns:
    Einsum equations for the query-key dot product and the score-value
    combination, plus the rank of the attention scores.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
                                        product_notation)
  attn_scores_rank = len(product_notation)
  combine_equation = "%s,%s->%s" % (product_notation, source_notation,
                                    target_notation)
  return dot_product_equation, combine_equation, attn_scores_rank


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention."""
  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = "%s,%s->%s" % (input_str, kernel_str, output_str)

  return equation, bias_axes, len(output_str)


def _get_output_shape(output_rank, known_last_dims):
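  # For example (illustrative), _get_output_shape(3, [num_heads, key_dim])
  # gives [None, num_heads, key_dim]: leading dims are left unknown.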
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@keras_export("keras.layers.MultiHeadAttention")
class MultiHeadAttention(Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `query`, `key`, `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_heads`, where the
  corresponding shapes are [batch_size, <query dimensions>, key_dim],
  [batch_size, <key/value dimensions>, key_dim],
  [batch_size, <key/value dimensions>, value_dim].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.

  Finally, the result tensor with the last dimension as `value_dim` can take
  a linear projection and be returned.

  Examples:

  Performs 1D cross-attention over two sequence inputs and also returns the
  per-head attention weights.

  >>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
  >>> target = tf.keras.Input(shape=[8, 16])
  >>> source = tf.keras.Input(shape=[4, 16])
  >>> output_tensor, weights = layer(target, source,
  ...                                return_attention_scores=True)
  >>> print(output_tensor.shape)
  (None, 8, 16)
  >>> print(weights.shape)
  (None, 2, 8, 4)

  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

  >>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
  >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
  >>> output_tensor = layer(input_tensor, input_tensor)
  >>> print(output_tensor.shape)
  (None, 5, 3, 4, 16)
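
  As a further illustration, `output_shape` controls the size of the last
  output dimension instead of projecting back to the query feature dimension.

  >>> layer = MultiHeadAttention(num_heads=2, key_dim=2, output_shape=[32])
  >>> input_tensor = tf.keras.Input(shape=[8, 16])
  >>> output_tensor = layer(input_tensor, input_tensor)
  >>> print(output_tensor.shape)
  (None, 8, 32)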

  Args:
    num_heads: Number of attention heads.
    key_dim: Size of each attention head for query and key.
    value_dim: Size of each attention head for value.
    dropout: Dropout probability.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the query feature dim
      (the query input's last dimension).
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes except batch, heads, and features.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.

  Call arguments:
    query: Query `Tensor` of shape `[B, T, dim]`.
    value: Value `Tensor` of shape `[B, S, dim]`.
    key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
      `value` for both `key` and `value`, which is the most common case.
    attention_mask: a boolean mask of shape `[B, T, S]` that prevents
      attention to certain positions. The mask specifies which query elements
      can attend to which key elements: 1 indicates attention and 0 indicates
      no attention. Broadcasting can happen for the missing batch dimensions
      and the head dimension.
    return_attention_scores: A boolean to indicate whether the output should
      be `(attention_output, attention_scores)` if `True`, or
      `attention_output` if `False`. Defaults to `False`.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
      Defaults to either using the training mode of the parent layer/model,
      or `False` (inference) if there is no parent layer.

  Returns:
    attention_output: The result of the computation, of shape `[B, T, E]`,
      where `T` is the target sequence length and `E` is the query input's
      last dimension if `output_shape` is `None`. Otherwise, the multi-head
      outputs are projected to the shape specified by `output_shape`.
    attention_scores: [Optional] multi-head attention coefficients over
      attention axes.
  """

  def __init__(self,
               num_heads,
               key_dim,
               value_dim=None,
               dropout=0.0,
               use_bias=True,
               output_shape=None,
               attention_axes=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._key_dim = key_dim
    self._value_dim = value_dim if value_dim else key_dim
    self._dropout = dropout
    self._use_bias = use_bias
    self._output_shape = output_shape
    self._kernel_initializer = initializers.get(kernel_initializer)
    self._bias_initializer = initializers.get(bias_initializer)
    self._kernel_regularizer = regularizers.get(kernel_regularizer)
    self._bias_regularizer = regularizers.get(bias_regularizer)
    self._kernel_constraint = constraints.get(kernel_constraint)
    self._bias_constraint = constraints.get(bias_constraint)
    if attention_axes is not None and not isinstance(attention_axes,
                                                     collections.abc.Sized):
      self._attention_axes = (attention_axes,)
    else:
      self._attention_axes = attention_axes
    self._built_from_signature = False
    self._query_shape, self._key_shape, self._value_shape = None, None, None

  def get_config(self):
    config = {
        "num_heads":
            self._num_heads,
        "key_dim":
            self._key_dim,
        "value_dim":
            self._value_dim,
        "dropout":
            self._dropout,
        "use_bias":
            self._use_bias,
        "output_shape":
            self._output_shape,
        "attention_axes":
            self._attention_axes,
        "kernel_initializer":
            initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            constraints.serialize(self._bias_constraint),
        "query_shape": self._query_shape,
        "key_shape": self._key_shape,
        "value_shape": self._value_shape,
    }
    base_config = super(MultiHeadAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    # If the layer has a different build() function from the Keras default,
    # we need to trigger the customized build to create weights.
    query_shape = config.pop("query_shape")
    key_shape = config.pop("key_shape")
    value_shape = config.pop("value_shape")
    layer = cls(**config)
    if None in [query_shape, key_shape, value_shape]:
      logging.warning(
          "One of the input shapes is missing. They should have been "
          "memorized when the layer was serialized. "
          "%s is created without weights.",
          str(cls))
    else:
      layer._build_from_signature(query_shape, value_shape, key_shape)  # pylint: disable=protected-access
    return layer

  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: query tensor or TensorShape.
      value: value tensor or TensorShape.
      key: key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
      self._query_shape = tensor_shape.TensorShape(query.shape)
    else:
      self._query_shape = tensor_shape.TensorShape(query)
    if hasattr(value, "shape"):
      self._value_shape = tensor_shape.TensorShape(value.shape)
    else:
      self._value_shape = tensor_shape.TensorShape(value)
    if key is None:
      self._key_shape = self._value_shape
    elif hasattr(key, "shape"):
      self._key_shape = tensor_shape.TensorShape(key.shape)
    else:
      self._key_shape = tensor_shape.TensorShape(key)

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Any setup work performed only once should happen in an `init_scope`
    # to avoid creating symbolic Tensors that will later pollute any eager
    # operations.
    with tf_utils.maybe_init_scope(self):
      free_dims = self._query_shape.rank - 1
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims, bound_dims=1, output_dims=2)
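      # For a rank-3 query of shape [B, T, dim], the equation above is, for
      # example, "abc,cde->abde": a projection to [B, T, num_heads, key_dim].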
      self._query_dense = einsum_dense.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_dim]),
          bias_axes=bias_axes if self._use_bias else None,
          name="query",
          **common_kwargs)
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          self._key_shape.rank - 1, bound_dims=1, output_dims=2)
      self._key_dense = einsum_dense.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_dim]),
          bias_axes=bias_axes if self._use_bias else None,
          name="key",
          **common_kwargs)
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          self._value_shape.rank - 1, bound_dims=1, output_dims=2)
      self._value_dense = einsum_dense.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._value_dim]),
          bias_axes=bias_axes if self._use_bias else None,
          name="value",
          **common_kwargs)

      # Builds the attention computations for multi-head dot product attention.
      # These computations could be wrapped into the keras attention layer once
      # it supports multi-head einsum computations.
      self._build_attention(output_rank)
      if self._output_shape:
        if not isinstance(self._output_shape, collections.abc.Sized):
          output_shape = [self._output_shape]
        else:
          output_shape = self._output_shape
      else:
        output_shape = [self._query_shape[-1]]
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims, bound_dims=2, output_dims=len(output_shape))
      self._output_dense = einsum_dense.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1, output_shape),
          bias_axes=bias_axes if self._use_bias else None,
          name="attention_output",
          **common_kwargs)

  def _build_attention(self, rank):
    """Builds multi-head dot-product attention computations.

    This function builds attributes necessary for `_compute_attention` so that
    a customized attention computation can replace the default dot-product
    attention.

    Args:
      rank: the rank of query, key, value tensors.
    """
    if self._attention_axes is None:
      self._attention_axes = tuple(range(1, rank - 2))
    else:
      self._attention_axes = tuple(self._attention_axes)
    self._dot_product_equation, self._combine_equation, attn_scores_rank = (
        _build_attention_equation(rank, attn_axes=self._attention_axes))
    norm_axes = tuple(
        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
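    # For example (illustrative), with rank=4 and a single attention axis the
    # scores have shape [B, N, T, S] and `norm_axes` is (3,), so the softmax
    # normalizes over the key dimension S.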
    self._softmax = advanced_activations.Softmax(axis=norm_axes)
    self._dropout_layer = core.Dropout(rate=self._dropout)

  def _masked_softmax(self, attention_scores, attention_mask=None):
    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
    if attention_mask is not None:
      # The expand dim happens starting from the `num_heads` dimension,
      # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
      mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
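      # For example (illustrative), a [B, T, S] mask used with [B, N, T, S]
      # scores gets one new axis inserted at position -3, giving [B, 1, T, S],
      # which then broadcasts across heads.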
      for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
        attention_mask = array_ops.expand_dims(
            attention_mask, axis=mask_expansion_axes)
    return self._softmax(attention_scores, attention_mask)

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         attention_mask=None,
                         training=None):
    """Applies dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = math_ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = special_math_ops.einsum(self._dot_product_equation, key,
                                               query)

    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores_dropout = self._dropout_layer(
        attention_scores, training=training)

    # `context_layer` = [B, T, N, H]
    attention_output = special_math_ops.einsum(self._combine_equation,
                                               attention_scores_dropout, value)
    return attention_output, attention_scores

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           return_attention_scores=False,
           training=None):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    #   N = `num_heads`
    #   H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, H]
    value = self._value_dense(value)

    attention_output, attention_scores = self._compute_attention(
        query, key, value, attention_mask, training)
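    # `attention_output` = [B, T, N, H] here; the final einsum-dense projects
    # it to [B, T, E] (or to `output_shape` if one was given).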
    attention_output = self._output_dense(attention_output)

    if return_attention_scores:
      return attention_output, attention_scores
    return attention_output