1# Lint as: python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Keras-based attention layer.""" 17# pylint: disable=g-classes-have-attributes 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import math 24import string 25 26import numpy as np 27 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.keras import constraints 30from tensorflow.python.keras import initializers 31from tensorflow.python.keras import regularizers 32from tensorflow.python.keras.engine.base_layer import Layer 33from tensorflow.python.keras.layers import advanced_activations 34from tensorflow.python.keras.layers import core 35from tensorflow.python.keras.layers import einsum_dense 36from tensorflow.python.keras.utils import tf_utils 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import special_math_ops 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.util.tf_export import keras_export 42 43 44_CHR_IDX = string.ascii_lowercase 45 46 47def _build_attention_equation(rank, attn_axes): 48 """Builds einsum equations for the attention computation. 49 50 Query, key, value inputs after projection are expected to have the shape as: 51 (bs, <non-attention dims>, <attention dims>, num_heads, channels). 52 bs and <non-attention dims> are treated as <batch dims>. 53 The attention operations can be generalized: 54 (1) Query-key dot product: 55 (<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>, 56 <key attention dims>, num_heads, channels) -> (<batch dims>, 57 num_heads, <query attention dims>, <key attention dims>) 58 (2) Combination: 59 (<batch dims>, num_heads, <query attention dims>, <key attention dims>), 60 (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>, 61 <query attention dims>, num_heads, channels) 62 63 Args: 64 rank: the rank of query, key, value tensors. 65 attn_axes: a list/tuple of axes, [-1, rank), that will do attention. 66 67 Returns: 68 Einsum equations. 69 """ 70 target_notation = _CHR_IDX[:rank] 71 # `batch_dims` includes the head dim. 72 batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,))) 73 letter_offset = rank 74 source_notation = "" 75 for i in range(rank): 76 if i in batch_dims or i == rank - 1: 77 source_notation += target_notation[i] 78 else: 79 source_notation += _CHR_IDX[letter_offset] 80 letter_offset += 1 81 82 product_notation = "".join([target_notation[i] for i in batch_dims] + 83 [target_notation[i] for i in attn_axes] + 84 [source_notation[i] for i in attn_axes]) 85 dot_product_equation = "%s,%s->%s" % (source_notation, target_notation, 86 product_notation) 87 attn_scores_rank = len(product_notation) 88 combine_equation = "%s,%s->%s" % (product_notation, source_notation, 89 target_notation) 90 return dot_product_equation, combine_equation, attn_scores_rank 91 92 93def _build_proj_equation(free_dims, bound_dims, output_dims): 94 """Builds an einsum equation for projections inside multi-head attention.""" 95 input_str = "" 96 kernel_str = "" 97 output_str = "" 98 bias_axes = "" 99 letter_offset = 0 100 for i in range(free_dims): 101 char = _CHR_IDX[i + letter_offset] 102 input_str += char 103 output_str += char 104 105 letter_offset += free_dims 106 for i in range(bound_dims): 107 char = _CHR_IDX[i + letter_offset] 108 input_str += char 109 kernel_str += char 110 111 letter_offset += bound_dims 112 for i in range(output_dims): 113 char = _CHR_IDX[i + letter_offset] 114 kernel_str += char 115 output_str += char 116 bias_axes += char 117 equation = "%s,%s->%s" % (input_str, kernel_str, output_str) 118 119 return equation, bias_axes, len(output_str) 120 121 122def _get_output_shape(output_rank, known_last_dims): 123 return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) 124 125 126@keras_export("keras.layers.MultiHeadAttention") 127class MultiHeadAttention(Layer): 128 """MultiHeadAttention layer. 129 130 This is an implementation of multi-headed attention based on "Attention 131 is all you Need". If `query`, `key,` `value` are the same, then 132 this is self-attention. Each timestep in `query` attends to the 133 corresponding sequence in `key`, and returns a fixed-width vector. 134 135 This layer first projects `query`, `key` and `value`. These are 136 (effectively) a list of tensors of length `num_attention_heads`, where the 137 corresponding shapes are [batch_size, <query dimensions>, key_dim], 138 [batch_size, <key/value dimensions>, key_dim], 139 [batch_size, <key/value dimensions>, value_dim]. 140 141 Then, the query and key tensors are dot-producted and scaled. These are 142 softmaxed to obtain attention probabilities. The value tensors are then 143 interpolated by these probabilities, then concatenated back to a single 144 tensor. 145 146 Finally, the result tensor with the last dimension as value_dim can take an 147 linear projection and return. 148 149 Examples: 150 151 Performs 1D cross-attention over two sequence inputs with an attention mask. 152 Returns the additional attention weights over heads. 153 154 >>> layer = MultiHeadAttention(num_heads=2, key_dim=2) 155 >>> target = tf.keras.Input(shape=[8, 16]) 156 >>> source = tf.keras.Input(shape=[4, 16]) 157 >>> output_tensor, weights = layer(target, source, 158 ... return_attention_scores=True) 159 >>> print(output_tensor.shape) 160 (None, 8, 16) 161 >>> print(weights.shape) 162 (None, 2, 8, 4) 163 164 Performs 2D self-attention over a 5D input tensor on axes 2 and 3. 165 166 >>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3)) 167 >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16]) 168 >>> output_tensor = layer(input_tensor, input_tensor) 169 >>> print(output_tensor.shape) 170 (None, 5, 3, 4, 16) 171 172 Args: 173 num_heads: Number of attention heads. 174 key_dim: Size of each attention head for query and key. 175 value_dim: Size of each attention head for value. 176 dropout: Dropout probability. 177 use_bias: Boolean, whether the dense layers use bias vectors/matrices. 178 output_shape: The expected shape of an output tensor, besides the batch and 179 sequence dims. If not specified, projects back to the key feature dim. 180 attention_axes: axes over which the attention is applied. `None` means 181 attention over all axes, but batch, heads, and features. 182 kernel_initializer: Initializer for dense layer kernels. 183 bias_initializer: Initializer for dense layer biases. 184 kernel_regularizer: Regularizer for dense layer kernels. 185 bias_regularizer: Regularizer for dense layer biases. 186 activity_regularizer: Regularizer for dense layer activity. 187 kernel_constraint: Constraint for dense layer kernels. 188 bias_constraint: Constraint for dense layer kernels. 189 190 Call arguments: 191 query: Query `Tensor` of shape `[B, T, dim]`. 192 value: Value `Tensor` of shape `[B, S, dim]`. 193 key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use 194 `value` for both `key` and `value`, which is the most common case. 195 attention_mask: a boolean mask of shape `[B, T, S]`, that prevents 196 attention to certain positions. The boolean mask specifies which query 197 elements can attend to which key elements, 1 indicates attention and 0 198 indicates no attention. Broadcasting can happen for the missing batch 199 dimensions and the head dimension. 200 return_attention_scores: A boolean to indicate whether the output should 201 be attention output if True, or (attention_output, attention_scores) if 202 False. Defaults to False. 203 training: Python boolean indicating whether the layer should behave in 204 training mode (adding dropout) or in inference mode (no dropout). 205 Defaults to either using the training mode of the parent layer/model, 206 or False (inference) if there is no parent layer. 207 208 Returns: 209 attention_output: The result of the computation, of shape [B, T, E], 210 where `T` is for target sequence shapes and `E` is the query input last 211 dimension if `output_shape` is `None`. Otherwise, the multi-head outputs 212 are project to the shape specified by `output_shape`. 213 attention_scores: [Optional] multi-head attention coeffients over 214 attention axes. 215 """ 216 217 def __init__(self, 218 num_heads, 219 key_dim, 220 value_dim=None, 221 dropout=0.0, 222 use_bias=True, 223 output_shape=None, 224 attention_axes=None, 225 kernel_initializer="glorot_uniform", 226 bias_initializer="zeros", 227 kernel_regularizer=None, 228 bias_regularizer=None, 229 activity_regularizer=None, 230 kernel_constraint=None, 231 bias_constraint=None, 232 **kwargs): 233 super(MultiHeadAttention, self).__init__(**kwargs) 234 self._num_heads = num_heads 235 self._key_dim = key_dim 236 self._value_dim = value_dim if value_dim else key_dim 237 self._dropout = dropout 238 self._use_bias = use_bias 239 self._output_shape = output_shape 240 self._kernel_initializer = initializers.get(kernel_initializer) 241 self._bias_initializer = initializers.get(bias_initializer) 242 self._kernel_regularizer = regularizers.get(kernel_regularizer) 243 self._bias_regularizer = regularizers.get(bias_regularizer) 244 self._kernel_constraint = constraints.get(kernel_constraint) 245 self._bias_constraint = constraints.get(bias_constraint) 246 if attention_axes is not None and not isinstance(attention_axes, 247 collections.abc.Sized): 248 self._attention_axes = (attention_axes,) 249 else: 250 self._attention_axes = attention_axes 251 self._built_from_signature = False 252 self._query_shape, self._key_shape, self._value_shape = None, None, None 253 254 def get_config(self): 255 config = { 256 "num_heads": 257 self._num_heads, 258 "key_dim": 259 self._key_dim, 260 "value_dim": 261 self._value_dim, 262 "dropout": 263 self._dropout, 264 "use_bias": 265 self._use_bias, 266 "output_shape": 267 self._output_shape, 268 "attention_axes": 269 self._attention_axes, 270 "kernel_initializer": 271 initializers.serialize(self._kernel_initializer), 272 "bias_initializer": 273 initializers.serialize(self._bias_initializer), 274 "kernel_regularizer": 275 regularizers.serialize(self._kernel_regularizer), 276 "bias_regularizer": 277 regularizers.serialize(self._bias_regularizer), 278 "activity_regularizer": 279 regularizers.serialize(self._activity_regularizer), 280 "kernel_constraint": 281 constraints.serialize(self._kernel_constraint), 282 "bias_constraint": 283 constraints.serialize(self._bias_constraint), 284 "query_shape": self._query_shape, 285 "key_shape": self._key_shape, 286 "value_shape": self._value_shape, 287 } 288 base_config = super(MultiHeadAttention, self).get_config() 289 return dict(list(base_config.items()) + list(config.items())) 290 291 @classmethod 292 def from_config(cls, config): 293 # If the layer has a different build() function from the Keras default, 294 # we need to trigger the customized build to create weights. 295 query_shape = config.pop("query_shape") 296 key_shape = config.pop("key_shape") 297 value_shape = config.pop("value_shape") 298 layer = cls(**config) 299 if None in [query_shape, key_shape, value_shape]: 300 logging.warning( 301 "One of the input shape is missing. They should be " 302 "memorized when the layer was serialized. " 303 "%s is created without weights.", 304 str(cls)) 305 else: 306 layer._build_from_signature(query_shape, value_shape, key_shape) # pylint: disable=protected-access 307 return layer 308 309 def _build_from_signature(self, query, value, key=None): 310 """Builds layers and variables. 311 312 Once the method is called, self._built_from_signature will be set to True. 313 314 Args: 315 query: query tensor or TensorShape. 316 value: value tensor or TensorShape. 317 key: key tensor or TensorShape. 318 """ 319 self._built_from_signature = True 320 if hasattr(query, "shape"): 321 self._query_shape = tensor_shape.TensorShape(query.shape) 322 else: 323 self._query_shape = tensor_shape.TensorShape(query) 324 if hasattr(value, "shape"): 325 self._value_shape = tensor_shape.TensorShape(value.shape) 326 else: 327 self._value_shape = tensor_shape.TensorShape(value) 328 if key is None: 329 self._key_shape = self._value_shape 330 elif hasattr(key, "shape"): 331 self._key_shape = tensor_shape.TensorShape(key.shape) 332 else: 333 self._key_shape = tensor_shape.TensorShape(key) 334 335 common_kwargs = dict( 336 kernel_initializer=self._kernel_initializer, 337 bias_initializer=self._bias_initializer, 338 kernel_regularizer=self._kernel_regularizer, 339 bias_regularizer=self._bias_regularizer, 340 activity_regularizer=self._activity_regularizer, 341 kernel_constraint=self._kernel_constraint, 342 bias_constraint=self._bias_constraint) 343 # Any setup work performed only once should happen in an `init_scope` 344 # to avoid creating symbolic Tensors that will later pollute any eager 345 # operations. 346 with tf_utils.maybe_init_scope(self): 347 free_dims = self._query_shape.rank - 1 348 einsum_equation, bias_axes, output_rank = _build_proj_equation( 349 free_dims, bound_dims=1, output_dims=2) 350 self._query_dense = einsum_dense.EinsumDense( 351 einsum_equation, 352 output_shape=_get_output_shape(output_rank - 1, 353 [self._num_heads, self._key_dim]), 354 bias_axes=bias_axes if self._use_bias else None, 355 name="query", 356 **common_kwargs) 357 einsum_equation, bias_axes, output_rank = _build_proj_equation( 358 self._key_shape.rank - 1, bound_dims=1, output_dims=2) 359 self._key_dense = einsum_dense.EinsumDense( 360 einsum_equation, 361 output_shape=_get_output_shape(output_rank - 1, 362 [self._num_heads, self._key_dim]), 363 bias_axes=bias_axes if self._use_bias else None, 364 name="key", 365 **common_kwargs) 366 einsum_equation, bias_axes, output_rank = _build_proj_equation( 367 self._value_shape.rank - 1, bound_dims=1, output_dims=2) 368 self._value_dense = einsum_dense.EinsumDense( 369 einsum_equation, 370 output_shape=_get_output_shape(output_rank - 1, 371 [self._num_heads, self._value_dim]), 372 bias_axes=bias_axes if self._use_bias else None, 373 name="value", 374 **common_kwargs) 375 376 # Builds the attention computations for multi-head dot product attention. 377 # These computations could be wrapped into the keras attention layer once 378 # it support mult-head einsum computations. 379 self._build_attention(output_rank) 380 if self._output_shape: 381 if not isinstance(self._output_shape, collections.abc.Sized): 382 output_shape = [self._output_shape] 383 else: 384 output_shape = self._output_shape 385 else: 386 output_shape = [self._query_shape[-1]] 387 einsum_equation, bias_axes, output_rank = _build_proj_equation( 388 free_dims, bound_dims=2, output_dims=len(output_shape)) 389 self._output_dense = einsum_dense.EinsumDense( 390 einsum_equation, 391 output_shape=_get_output_shape(output_rank - 1, output_shape), 392 bias_axes=bias_axes if self._use_bias else None, 393 name="attention_output", 394 **common_kwargs) 395 396 def _build_attention(self, rank): 397 """Builds multi-head dot-product attention computations. 398 399 This function builds attributes necessary for `_compute_attention` to 400 costomize attention computation to replace the default dot-product 401 attention. 402 403 Args: 404 rank: the rank of query, key, value tensors. 405 """ 406 if self._attention_axes is None: 407 self._attention_axes = tuple(range(1, rank - 2)) 408 else: 409 self._attention_axes = tuple(self._attention_axes) 410 self._dot_product_equation, self._combine_equation, attn_scores_rank = ( 411 _build_attention_equation(rank, attn_axes=self._attention_axes)) 412 norm_axes = tuple( 413 range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) 414 self._softmax = advanced_activations.Softmax(axis=norm_axes) 415 self._dropout_layer = core.Dropout(rate=self._dropout) 416 417 def _masked_softmax(self, attention_scores, attention_mask=None): 418 # Normalize the attention scores to probabilities. 419 # `attention_scores` = [B, N, T, S] 420 if attention_mask is not None: 421 # The expand dim happens starting from the `num_heads` dimension, 422 # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>) 423 mask_expansion_axes = [-len(self._attention_axes) * 2 - 1] 424 for _ in range(len(attention_scores.shape) - len(attention_mask.shape)): 425 attention_mask = array_ops.expand_dims( 426 attention_mask, axis=mask_expansion_axes) 427 return self._softmax(attention_scores, attention_mask) 428 429 def _compute_attention(self, 430 query, 431 key, 432 value, 433 attention_mask=None, 434 training=None): 435 """Applies Dot-product attention with query, key, value tensors. 436 437 This function defines the computation inside `call` with projected 438 multi-head Q, K, V inputs. Users can override this function for customized 439 attention implementation. 440 441 Args: 442 query: Projected query `Tensor` of shape `[B, T, N, key_dim]`. 443 key: Projected key `Tensor` of shape `[B, T, N, key_dim]`. 444 value: Projected value `Tensor` of shape `[B, T, N, value_dim]`. 445 attention_mask: a boolean mask of shape `[B, T, S]`, that prevents 446 attention to certain positions. 447 training: Python boolean indicating whether the layer should behave in 448 training mode (adding dropout) or in inference mode (doing nothing). 449 450 Returns: 451 attention_output: Multi-headed outputs of attention computation. 452 attention_scores: Multi-headed attention weights. 453 """ 454 # Note: Applying scalar multiply at the smaller end of einsum improves 455 # XLA performance, but may introduce slight numeric differences in 456 # the Transformer attention head. 457 query = math_ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim))) 458 459 # Take the dot product between "query" and "key" to get the raw 460 # attention scores. 461 attention_scores = special_math_ops.einsum(self._dot_product_equation, key, 462 query) 463 464 attention_scores = self._masked_softmax(attention_scores, attention_mask) 465 466 # This is actually dropping out entire tokens to attend to, which might 467 # seem a bit unusual, but is taken from the original Transformer paper. 468 attention_scores_dropout = self._dropout_layer( 469 attention_scores, training=training) 470 471 # `context_layer` = [B, T, N, H] 472 attention_output = special_math_ops.einsum(self._combine_equation, 473 attention_scores_dropout, value) 474 return attention_output, attention_scores 475 476 def call(self, 477 query, 478 value, 479 key=None, 480 attention_mask=None, 481 return_attention_scores=False, 482 training=None): 483 if not self._built_from_signature: 484 self._build_from_signature(query=query, value=value, key=key) 485 if key is None: 486 key = value 487 488 # N = `num_attention_heads` 489 # H = `size_per_head` 490 # `query` = [B, T, N ,H] 491 query = self._query_dense(query) 492 493 # `key` = [B, S, N, H] 494 key = self._key_dense(key) 495 496 # `value` = [B, S, N, H] 497 value = self._value_dense(value) 498 499 attention_output, attention_scores = self._compute_attention( 500 query, key, value, attention_mask, training) 501 attention_output = self._output_dense(attention_output) 502 503 if return_attention_scores: 504 return attention_output, attention_scores 505 return attention_output 506