# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layers that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class BaseDenseAttention(Layer):
  """Base Attention class for Dense networks.

  This class is suitable for Dense or CNN networks, and not for RNN networks.

  Implementations of attention mechanisms should inherit from this class,
  implement `_calculate_scores()`, and reuse the `_apply_scores()` method.

  Args:
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.
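
  Here is a minimal sketch of how a subclass would plug a custom scoring
  function into this base class. `MyDotProductAttention` is purely
  illustrative and not part of the API; the base class takes care of masking,
  softmax, dropout and the weighted sum over `value`:

  ```python
  class MyDotProductAttention(BaseDenseAttention):

    def _calculate_scores(self, query, key):
      # Scores of shape [batch_size, Tq, Tv] from a query-key dot product.
      return tf.matmul(query, key, transpose_b=True)
  ```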
78 """ 79 80 def __init__(self, causal=False, dropout=0.0, 81 **kwargs): 82 super(BaseDenseAttention, self).__init__(**kwargs) 83 self.causal = causal 84 self.dropout = dropout 85 self.supports_masking = True 86 87 def _calculate_scores(self, query, key): 88 """Calculates attention scores. 89 90 Args: 91 query: Query tensor of shape `[batch_size, Tq, dim]`. 92 key: Key tensor of shape `[batch_size, Tv, dim]`. 93 94 Returns: 95 Tensor of shape `[batch_size, Tq, Tv]`. 96 """ 97 return NotImplementedError 98 99 def _apply_scores(self, scores, value, scores_mask=None, training=None): 100 """Applies attention scores to the given value tensor. 101 102 To use this method in your attention layer, follow the steps: 103 104 * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape 105 `[batch_size, Tv]` to calculate the attention `scores`. 106 * Pass `scores` and `value` tensors to this method. The method applies 107 `scores_mask`, calculates `attention_distribution = softmax(scores)`, then 108 returns `matmul(attention_distribution, value). 109 * Apply `query_mask` and return the result. 110 111 Args: 112 scores: Scores float tensor of shape `[batch_size, Tq, Tv]`. 113 value: Value tensor of shape `[batch_size, Tv, dim]`. 114 scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or 115 `[batch_size, Tq, Tv]`. If given, scores at positions where 116 `scores_mask==False` do not contribute to the result. It must contain 117 at least one `True` value in each line along the last dimension. 118 training: Python boolean indicating whether the layer should behave in 119 training mode (adding dropout) or in inference mode (no dropout). 120 121 Returns: 122 Tensor of shape `[batch_size, Tq, dim]`. 123 Attention scores after masking and softmax with shape 124 `[batch_size, Tq, Tv]`. 125 """ 126 if scores_mask is not None: 127 padding_mask = math_ops.logical_not(scores_mask) 128 # Bias so padding positions do not contribute to attention distribution. 129 # Note 65504. is the max float16 value. 130 if scores.dtype is dtypes.float16: 131 scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype) 132 else: 133 scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype) 134 if training is None: 135 training = K.learning_phase() 136 weights = nn.softmax(scores) 137 138 def dropped_weights(): 139 return nn.dropout(weights, rate=self.dropout) 140 141 weights = control_flow_util.smart_cond(training, dropped_weights, 142 lambda: array_ops.identity(weights)) 143 return math_ops.matmul(weights, value), weights 144 145 # TODO(b/125916026): Consider exposing a __call__ method with named args. 146 def call(self, 147 inputs, 148 mask=None, 149 training=None, 150 return_attention_scores=False): 151 self._validate_call_args(inputs=inputs, mask=mask) 152 q = inputs[0] 153 v = inputs[1] 154 k = inputs[2] if len(inputs) > 2 else v 155 q_mask = mask[0] if mask else None 156 v_mask = mask[1] if mask else None 157 scores = self._calculate_scores(query=q, key=k) 158 if v_mask is not None: 159 # Mask of shape [batch_size, 1, Tv]. 160 v_mask = array_ops.expand_dims(v_mask, axis=-2) 161 if self.causal: 162 # Creates a lower triangular mask, so position i cannot attend to 163 # positions j>i. This prevents the flow of information from the future 164 # into the past. 165 scores_shape = array_ops.shape(scores) 166 # causal_mask_shape = [1, Tq, Tv]. 
      causal_mask_shape = array_ops.concat(
          [array_ops.ones_like(scores_shape[:-2]), scores_shape[-2:]],
          axis=0)
      causal_mask = _lower_triangular_mask(causal_mask_shape)
    else:
      causal_mask = None
    scores_mask = _merge_masks(v_mask, causal_mask)
    result, attention_scores = self._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask, training=training)
    if q_mask is not None:
      # Mask of shape [batch_size, Tq, 1].
      q_mask = array_ops.expand_dims(q_mask, axis=-1)
      result *= math_ops.cast(q_mask, dtype=result.dtype)
    if return_attention_scores:
      return result, attention_scores
    return result

  def compute_mask(self, inputs, mask=None):
    self._validate_call_args(inputs=inputs, mask=mask)
    if mask:
      q_mask = mask[0]
      if q_mask is None:
        return None
      return ops.convert_to_tensor_v2_with_dispatch(q_mask)
    return None

  def _validate_call_args(self, inputs, mask):
    """Validates arguments of the call method."""
    class_name = self.__class__.__name__
    if not isinstance(inputs, list):
      raise ValueError(
          '{} layer must be called on a list of inputs, namely [query, value] '
          'or [query, value, key].'.format(class_name))
    if len(inputs) < 2 or len(inputs) > 3:
      raise ValueError(
          '{} layer accepts inputs list of length 2 or 3, '
          'namely [query, value] or [query, value, key]. '
          'Given length: {}'.format(class_name, len(inputs)))
    if mask:
      if not isinstance(mask, list):
        raise ValueError(
            '{} layer mask must be a list, '
            'namely [query_mask, value_mask].'.format(class_name))
      if len(mask) < 2 or len(mask) > len(inputs):
        raise ValueError(
            '{} layer mask must be a list of length 2, namely [query_mask, '
            'value_mask]. Given length: {}'.format(class_name, len(mask)))

  def get_config(self):
    config = {
        'causal': self.causal,
        'dropout': self.dropout,
    }
    base_config = super(BaseDenseAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Attention')
class Attention(BaseDenseAttention):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
     product: `scores = tf.matmul(query, key, transpose_b=True)`.
  2. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  3. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a scalar variable to scale the attention
      scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`.
        If not given, will use `value` for both `key` and `value`, which is
        the most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `Attention` in a CNN+Attention network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
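
  As a further illustrative sketch (the tensor shapes below are made up for
  the example), the layer can also be called directly on query/value tensors
  and can return the attention scores as a second output:

  ```python
  query = tf.random.normal(shape=(2, 5, 16))   # [batch_size, Tq, dim]
  value = tf.random.normal(shape=(2, 8, 16))   # [batch_size, Tv, dim]
  attention_output, attention_scores = tf.keras.layers.Attention()(
      [query, value], return_attention_scores=True)
  # attention_output has shape [2, 5, 16] ([batch_size, Tq, dim]) and
  # attention_scores has shape [2, 5, 8] ([batch_size, Tq, Tv]).
  ```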
  """

  def __init__(self, use_scale=False, **kwargs):
    super(Attention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    """Creates scale variable if use_scale==True."""
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=(),
          initializer=init_ops.ones_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(Attention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    scores = math_ops.matmul(query, key, transpose_b=True)
    if self.scale is not None:
      scores *= self.scale
    return scores

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(Attention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.AdditiveAttention')
class AdditiveAttention(BaseDenseAttention):
  """Additive attention layer, a.k.a. Bahdanau-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
     and `[batch_size, 1, Tv, dim]` respectively.
  2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
     sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`.
  3. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  4. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a variable to scale the attention scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `AdditiveAttention` in a CNN+Attention
  network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
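
  As a smaller, illustrative sketch (the shapes below are made up for the
  example), the layer can also be applied directly to query/value tensors:

  ```python
  query = tf.random.normal(shape=(2, 5, 16))   # [batch_size, Tq, dim]
  value = tf.random.normal(shape=(2, 8, 16))   # [batch_size, Tv, dim]
  attention_output = tf.keras.layers.AdditiveAttention()([query, value])
  # attention_output has shape [2, 5, 16], i.e. [batch_size, Tq, dim].
  ```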
  """

  def __init__(self, use_scale=True, **kwargs):
    super(AdditiveAttention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    v_shape = tensor_shape.TensorShape(input_shape[1])
    dim = v_shape[-1]
    if isinstance(dim, tensor_shape.Dimension):
      dim = dim.value
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=[dim],
          initializer=init_ops.glorot_uniform_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(AdditiveAttention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a nonlinear sum of query and key.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    # Reshape tensors to enable broadcasting.
    # Reshape into [batch_size, Tq, 1, dim].
    q_reshaped = array_ops.expand_dims(query, axis=-2)
    # Reshape into [batch_size, 1, Tv, dim].
    k_reshaped = array_ops.expand_dims(key, axis=-3)
    if self.use_scale:
      scale = self.scale
    else:
      scale = 1.
    return math_ops.reduce_sum(
        scale * math_ops.tanh(q_reshaped + k_reshaped), axis=-1)

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(AdditiveAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _lower_triangular_mask(shape):
  """Creates a lower-triangular boolean mask over the last 2 dimensions."""
  row_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-2)
  col_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-1)
  return math_ops.greater_equal(row_index, col_index)


def _merge_masks(x, y):
  if x is None:
    return y
  if y is None:
    return x
  return math_ops.logical_and(x, y)