# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Attention layers that can be used in sequence DNN/CNN models.
16
17This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
18Attention is formed by three tensors: Query, Key and Value.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.keras import backend as K
29from tensorflow.python.keras.engine.base_layer import Layer
30from tensorflow.python.keras.utils import control_flow_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import init_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import nn
35from tensorflow.python.util.tf_export import keras_export


class BaseDenseAttention(Layer):
  """Base Attention class for Dense networks.

  This class is suitable for Dense or CNN networks, and not for RNN networks.

  Implementations of attention mechanisms should inherit from this class, and
  reuse the `_apply_scores()` method (see the sketch at the end of this
  docstring).

  Args:
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.
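
  A minimal subclass sketch (the class name `MyDotProductAttention` is
  illustrative, not part of the API); only `_calculate_scores()` needs to be
  provided, since masking, softmax and dropout are handled by this base class:

  ```python
  class MyDotProductAttention(BaseDenseAttention):

    def _calculate_scores(self, query, key):
      # Scores of shape [batch_size, Tq, Tv].
      return tf.matmul(query, key, transpose_b=True)
  ```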
78  """
79
80  def __init__(self, causal=False, dropout=0.0,
81               **kwargs):
82    super(BaseDenseAttention, self).__init__(**kwargs)
83    self.causal = causal
84    self.dropout = dropout
85    self.supports_masking = True
86
87  def _calculate_scores(self, query, key):
88    """Calculates attention scores.
89
90    Args:
91      query: Query tensor of shape `[batch_size, Tq, dim]`.
92      key: Key tensor of shape `[batch_size, Tv, dim]`.
93
94    Returns:
95      Tensor of shape `[batch_size, Tq, Tv]`.
96    """
97    return NotImplementedError

  def _apply_scores(self, scores, value, scores_mask=None, training=None):
    """Applies attention scores to the given value tensor.

    To use this method in your attention layer, follow the steps:

    * Use `query` tensor of shape `[batch_size, Tq, dim]` and `key` tensor of
      shape `[batch_size, Tv, dim]` to calculate the attention `scores`.
    * Pass `scores` and `value` tensors to this method. The method applies
      `scores_mask`, calculates `attention_distribution = softmax(scores)`,
      then returns `matmul(attention_distribution, value)`.
    * Apply `query_mask` and return the result.

    Args:
      scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
      value: Value tensor of shape `[batch_size, Tv, dim]`.
      scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
        `[batch_size, Tq, Tv]`. If given, scores at positions where
        `scores_mask==False` do not contribute to the result. It must contain
        at least one `True` value in each row along the last dimension.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (no dropout).

    Returns:
      A tuple `(output, attention_scores)`:
        output: Attention outputs of shape `[batch_size, Tq, dim]`.
        attention_scores: Attention scores after masking and softmax with
          shape `[batch_size, Tq, Tv]`.
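
    For example, a subclass `call` could use it as follows (a sketch, with
    masking omitted):

    ```python
    scores = self._calculate_scores(query=query, key=key)
    result, attention_scores = self._apply_scores(scores=scores, value=value)
    ```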
125    """
126    if scores_mask is not None:
127      padding_mask = math_ops.logical_not(scores_mask)
128      # Bias so padding positions do not contribute to attention distribution.
129      # Note 65504. is the max float16 value.
130      if scores.dtype is dtypes.float16:
131        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
132      else:
133        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
134    if training is None:
135      training = K.learning_phase()
136    weights = nn.softmax(scores)
137
138    def dropped_weights():
139      return nn.dropout(weights, rate=self.dropout)
140
141    weights = control_flow_util.smart_cond(training, dropped_weights,
142                                           lambda: array_ops.identity(weights))
143    return math_ops.matmul(weights, value), weights

  # TODO(b/125916026): Consider exposing a __call__ method with named args.
  def call(self,
           inputs,
           mask=None,
           training=None,
           return_attention_scores=False):
    self._validate_call_args(inputs=inputs, mask=mask)
    q = inputs[0]
    v = inputs[1]
    k = inputs[2] if len(inputs) > 2 else v
    q_mask = mask[0] if mask else None
    v_mask = mask[1] if mask else None
    scores = self._calculate_scores(query=q, key=k)
    if v_mask is not None:
      # Mask of shape [batch_size, 1, Tv].
      v_mask = array_ops.expand_dims(v_mask, axis=-2)
    if self.causal:
      # Creates a lower triangular mask, so position i cannot attend to
      # positions j>i. This prevents the flow of information from the future
      # into the past.
      scores_shape = array_ops.shape(scores)
      # causal_mask_shape = [1, Tq, Tv].
      causal_mask_shape = array_ops.concat(
          [array_ops.ones_like(scores_shape[:-2]), scores_shape[-2:]],
          axis=0)
      causal_mask = _lower_triangular_mask(causal_mask_shape)
    else:
      causal_mask = None
    scores_mask = _merge_masks(v_mask, causal_mask)
    result, attention_scores = self._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask, training=training)
    if q_mask is not None:
      # Mask of shape [batch_size, Tq, 1].
      q_mask = array_ops.expand_dims(q_mask, axis=-1)
      result *= math_ops.cast(q_mask, dtype=result.dtype)
    if return_attention_scores:
      return result, attention_scores
    return result

  def compute_mask(self, inputs, mask=None):
    self._validate_call_args(inputs=inputs, mask=mask)
    if mask:
      q_mask = mask[0]
      if q_mask is None:
        return None
      return ops.convert_to_tensor_v2_with_dispatch(q_mask)
    return None

  def _validate_call_args(self, inputs, mask):
    """Validates arguments of the call method."""
    class_name = self.__class__.__name__
    if not isinstance(inputs, list):
      raise ValueError(
          '{} layer must be called on a list of inputs, namely [query, value] '
          'or [query, value, key].'.format(class_name))
    if len(inputs) < 2 or len(inputs) > 3:
      raise ValueError(
          '{} layer accepts inputs list of length 2 or 3, '
          'namely [query, value] or [query, value, key]. '
          'Given length: {}'.format(class_name, len(inputs)))
    if mask:
      if not isinstance(mask, list):
        raise ValueError(
            '{} layer mask must be a list, '
            'namely [query_mask, value_mask].'.format(class_name))
      if len(mask) < 2 or len(mask) > len(inputs):
        raise ValueError(
            '{} layer mask must be a list of length 2, namely [query_mask, '
            'value_mask]. Given length: {}'.format(class_name, len(mask)))

  def get_config(self):
    config = {
        'causal': self.causal,
        'dropout': self.dropout,
    }
    base_config = super(BaseDenseAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Attention')
class Attention(BaseDenseAttention):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
     product: `scores = tf.matmul(query, key, transpose_b=True)`.
  2. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  3. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.
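
  Put together (a sketch of the three steps above; masking, dropout and the
  optional learned scale factor are omitted):

  ```python
  scores = tf.matmul(query, key, transpose_b=True)  # [batch_size, Tq, Tv]
  distribution = tf.nn.softmax(scores)              # [batch_size, Tq, Tv]
  result = tf.matmul(distribution, value)           # [batch_size, Tq, dim]
  ```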

  Args:
    use_scale: If `True`, will create a scalar variable to scale the attention
      scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `Attention` in a CNN+Attention network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
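
  The attention distribution itself can be requested as a second output via
  `return_attention_scores` (an illustrative extension of the example above):

  ```python
  output, scores = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding],
      return_attention_scores=True)
  # scores has shape [batch_size, Tq, Tv].
  ```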
323  """
324
325  def __init__(self, use_scale=False, **kwargs):
326    super(Attention, self).__init__(**kwargs)
327    self.use_scale = use_scale
328
329  def build(self, input_shape):
330    """Creates scale variable if use_scale==True."""
331    if self.use_scale:
332      self.scale = self.add_weight(
333          name='scale',
334          shape=(),
335          initializer=init_ops.ones_initializer(),
336          dtype=self.dtype,
337          trainable=True)
338    else:
339      self.scale = None
340    super(Attention, self).build(input_shape)
341
342  def _calculate_scores(self, query, key):
343    """Calculates attention scores as a query-key dot product.
344
345    Args:
346      query: Query tensor of shape `[batch_size, Tq, dim]`.
347      key: Key tensor of shape `[batch_size, Tv, dim]`.
348    Returns:
349      Tensor of shape `[batch_size, Tq, Tv]`.
350    """
351    scores = math_ops.matmul(query, key, transpose_b=True)
352    if self.scale is not None:
353      scores *= self.scale
354    return scores
355
356  def get_config(self):
357    config = {'use_scale': self.use_scale}
358    base_config = super(Attention, self).get_config()
359    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.AdditiveAttention')
class AdditiveAttention(BaseDenseAttention):
  """Additive attention layer, a.k.a. Bahdanau-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
     and `[batch_size, 1, Tv, dim]` respectively.
  2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
     sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`.
  3. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  4. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.
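
  Put together (a sketch of the steps above, with masking and dropout omitted;
  with `use_scale=True` the implementation additionally multiplies the `tanh`
  term by a learned vector before the reduction):

  ```python
  q = tf.expand_dims(query, axis=-2)  # [batch_size, Tq, 1, dim]
  k = tf.expand_dims(key, axis=-3)    # [batch_size, 1, Tv, dim]
  scores = tf.reduce_sum(tf.tanh(q + k), axis=-1)  # [batch_size, Tq, Tv]
  distribution = tf.nn.softmax(scores)             # [batch_size, Tq, Tv]
  result = tf.matmul(distribution, value)          # [batch_size, Tq, dim]
  ```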

  Args:
    use_scale: If `True`, will create a variable to scale the attention scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `AdditiveAttention` in a CNN+Attention
  network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
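
  For decoder-style self-attention, the same tensor is passed as both query and
  value together with the causal mask (illustrative):

  ```python
  causal_self_attention = tf.keras.layers.AdditiveAttention(causal=True)(
      [query_seq_encoding, query_seq_encoding])
  ```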
463  """
464
465  def __init__(self, use_scale=True, **kwargs):
466    super(AdditiveAttention, self).__init__(**kwargs)
467    self.use_scale = use_scale
468
469  def build(self, input_shape):
470    v_shape = tensor_shape.TensorShape(input_shape[1])
471    dim = v_shape[-1]
472    if isinstance(dim, tensor_shape.Dimension):
473      dim = dim.value
474    if self.use_scale:
475      self.scale = self.add_weight(
476          name='scale',
477          shape=[dim],
478          initializer=init_ops.glorot_uniform_initializer(),
479          dtype=self.dtype,
480          trainable=True)
481    else:
482      self.scale = None
483    super(AdditiveAttention, self).build(input_shape)
484
485  def _calculate_scores(self, query, key):
486    """Calculates attention scores as a nonlinear sum of query and key.
487
488    Args:
489      query: Query tensor of shape `[batch_size, Tq, dim]`.
490      key: Key tensor of shape `[batch_size, Tv, dim]`.
491    Returns:
492      Tensor of shape `[batch_size, Tq, Tv]`.
493    """
494    # Reshape tensors to enable broadcasting.
495    # Reshape into [batch_size, Tq, 1, dim].
496    q_reshaped = array_ops.expand_dims(query, axis=-2)
497    # Reshape into [batch_size, 1, Tv, dim].
498    k_reshaped = array_ops.expand_dims(key, axis=-3)
499    if self.use_scale:
500      scale = self.scale
501    else:
502      scale = 1.
503    return math_ops.reduce_sum(
504        scale * math_ops.tanh(q_reshaped + k_reshaped), axis=-1)
505
506  def get_config(self):
507    config = {'use_scale': self.use_scale}
508    base_config = super(AdditiveAttention, self).get_config()
509    return dict(list(base_config.items()) + list(config.items()))
510
511
512def _lower_triangular_mask(shape):
513  """Creates a lower-triangular boolean mask over the last 2 dimensions."""
514  row_index = math_ops.cumsum(
515      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-2)
516  col_index = math_ops.cumsum(
517      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-1)
518  return math_ops.greater_equal(row_index, col_index)
519
520
521def _merge_masks(x, y):
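  """Merges two boolean masks with `logical_and`; a `None` mask is ignored."""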
  if x is None:
    return y
  if y is None:
    return x
  return math_ops.logical_and(x, y)