# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.keras.layers import advanced_activations

# b/123041942
# In TF 2.x, if `tf.nn.softmax` is used as an activation function in Keras
# layers, it gets serialized as 'softmax_v2' instead of 'softmax' because the
# internal method name is returned in serialization. This results in errors in
# model exporting and loading as Keras can't find any activation function with
# the name of `softmax_v2`.

# This dict maps the activation function name from its v2 version to its
# canonical name.
_TF_ACTIVATIONS_V2 = {
    'softmax_v2': 'softmax',
}


@keras_export('keras.activations.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1):
  """Softmax converts a real vector to a vector of categorical probabilities.

  The elements of the output vector are in range (0, 1) and sum to 1.

  Each vector is handled independently. The `axis` argument sets which axis
  of the input the function is applied along.

  Softmax is often used as the activation for the last
  layer of a classification network because the result could be interpreted as
  a probability distribution.

  The softmax of each vector x is computed as
  `exp(x) / tf.reduce_sum(exp(x))`.

  The input values are the log-odds of the resulting probability.

  Args:
    x: Input tensor.
    axis: Integer, axis along which the softmax normalization is applied.

  Returns:
    Tensor, output of softmax transformation (all values are non-negative
    and sum to 1).

  Raises:
    ValueError: In case `dim(x) == 1`.
  """
  rank = x.shape.rank
  if rank == 2:
    output = nn.softmax(x)
  elif rank > 2:
    e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
    s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
    output = e / s
  else:
    raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                     'Received input: %s' % (x,))

  # Cache the logits to use for crossentropy loss.
  output._keras_logits = x  # pylint: disable=protected-access
  return output
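

# Illustrative sketch (not part of the module API; values approximate and the
# `inputs` name is only for illustration): `softmax` normalizes each vector
# along `axis` independently, so every row of a rank-2 input sums to 1.
#
#   inputs = tf.constant([[1.0, 2.0, 3.0]])
#   tf.keras.activations.softmax(inputs)
#   # ~ [[0.090, 0.245, 0.665]]   (non-negative, row sums to 1)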


@keras_export('keras.activations.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.0):
  """Exponential Linear Unit.

  The exponential linear unit (ELU) with `alpha > 0` is:
  `x` if `x > 0` and
  `alpha * (exp(x) - 1)` if `x < 0`.
  The ELU hyperparameter `alpha` controls the value to which an
  ELU saturates for negative net inputs. ELUs diminish the
  vanishing gradient effect.

  ELUs have negative values, which push the mean of the activations
  closer to zero.
  Mean activations that are closer to zero enable faster learning as they
  bring the gradient closer to the natural gradient.
  ELUs saturate to a negative value when the argument gets smaller.
  Saturation means a small derivative, which decreases the variation
  and the information that is propagated to the next layer.

  Example Usage:

  >>> import tensorflow as tf
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='elu',
  ...           input_shape=(28, 28, 1)))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))

  Args:
    x: Input tensor.
    alpha: A scalar, slope of negative section. `alpha` controls the value to
      which an ELU saturates for negative net inputs.

  Returns:
    The exponential linear unit (ELU) activation function: `x` if `x > 0` and
    `alpha * (exp(x) - 1)` if `x < 0`.

  Reference:
    [Fast and Accurate Deep Network Learning by Exponential Linear Units
    (ELUs) (Clevert et al, 2016)](https://arxiv.org/abs/1511.07289)
  """
  return K.elu(x, alpha)
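

# Illustrative sketch (values approximate, `alpha` left at its default of 1.0;
# the `a` name is only for illustration): `elu` is the identity for positive
# inputs and saturates to `-alpha` for large negative inputs.
#
#   a = tf.constant([-3.0, -1.0, 0.0, 2.0])
#   tf.keras.activations.elu(a)
#   # ~ [-0.950, -0.632, 0.0, 2.0]   (exp(-3) - 1, exp(-1) - 1, 0, 2)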


@keras_export('keras.activations.selu')
@dispatch.add_dispatch_support
def selu(x):
  """Scaled Exponential Linear Unit (SELU).

  The Scaled Exponential Linear Unit (SELU) activation function is defined as:

  - `if x > 0: return scale * x`
  - `if x < 0: return scale * alpha * (exp(x) - 1)`

  where `alpha` and `scale` are pre-defined constants
  (`alpha=1.67326324` and `scale=1.05070098`).

  Basically, the SELU activation function multiplies `scale` (> 1) with the
  output of the `tf.keras.activations.elu` function to ensure a slope larger
  than one for positive inputs.

  The values of `alpha` and `scale` are
  chosen so that the mean and variance of the inputs are preserved
  between two consecutive layers as long as the weights are initialized
  correctly (see `tf.keras.initializers.LecunNormal` initializer)
  and the number of input units is "large enough"
  (see reference paper for more information).

  Example Usage:

  >>> num_classes = 10  # 10-class problem
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Dense(64, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(32, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(16, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

  Args:
    x: A tensor or variable to compute the activation function for.

  Returns:
    The scaled exponential linear unit activation: `scale * elu(x, alpha)`.

  Notes:
    - To be used together with the
      `tf.keras.initializers.LecunNormal` initializer.
    - To be used together with the dropout variant
      `tf.keras.layers.AlphaDropout` (not regular dropout).

  References:
    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
  """
  return nn.selu(x)


@keras_export('keras.activations.softplus')
@dispatch.add_dispatch_support
def softplus(x):
  """Softplus activation function, `softplus(x) = log(exp(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.softplus(a)
  >>> b.numpy()
  array([2.0611537e-09, 3.1326166e-01, 6.9314718e-01, 1.3132616e+00,
         2.0000000e+01], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The softplus activation: `log(exp(x) + 1)`.
  """
  return nn.softplus(x)


@keras_export('keras.activations.softsign')
@dispatch.add_dispatch_support
def softsign(x):
  """Softsign activation function, `softsign(x) = x / (abs(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float32)
  >>> b = tf.keras.activations.softsign(a)
  >>> b.numpy()
  array([-0.5,  0. ,  0.5], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The softsign activation: `x / (abs(x) + 1)`.
  """
  return nn.softsign(x)


@keras_export('keras.activations.swish')
@dispatch.add_dispatch_support
def swish(x):
  """Swish activation function, `swish(x) = x * sigmoid(x)`.

  Swish is a smooth, non-monotonic function that consistently matches
  or outperforms ReLU on deep networks. It is unbounded above and
  bounded below.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.swish(a)
  >>> b.numpy()
  array([-4.1223075e-08, -2.6894143e-01,  0.0000000e+00,  7.3105860e-01,
          2.0000000e+01], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The swish activation applied to `x` (see reference paper for details).

  Reference:
    - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)
  """
  return nn.swish(x)


@keras_export('keras.activations.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0):
  """Applies the rectified linear unit activation function.

  With default values, this returns the standard ReLU activation:
  `max(x, 0)`, the element-wise maximum of 0 and the input tensor.

  Modifying default parameters allows you to use non-zero thresholds,
  change the max value of the activation,
  and to use a non-zero multiple of the input for values below the threshold.

  For example:

  >>> foo = tf.constant([-10, -5, 0.0, 5, 10], dtype=tf.float32)
  >>> tf.keras.activations.relu(foo).numpy()
  array([ 0.,  0.,  0.,  5., 10.], dtype=float32)
  >>> tf.keras.activations.relu(foo, alpha=0.5).numpy()
  array([-5. , -2.5,  0. ,  5. , 10. ], dtype=float32)
  >>> tf.keras.activations.relu(foo, max_value=5).numpy()
  array([0., 0., 0., 5., 5.], dtype=float32)
  >>> tf.keras.activations.relu(foo, threshold=5).numpy()
  array([-0., -0.,  0.,  0., 10.], dtype=float32)

  Args:
    x: Input `tensor` or `variable`.
    alpha: A `float` that governs the slope for values lower than the
      threshold.
    max_value: A `float` that sets the saturation threshold (the largest value
      the function will return).
    threshold: A `float` giving the threshold value of the activation function
      below which values will be damped or set to zero.

  Returns:
    A `Tensor` representing the input tensor,
    transformed by the relu activation function.
    Tensor will be of the same shape and dtype of input `x`.
  """
  return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)


@keras_export('keras.activations.gelu', v1=[])
@dispatch.add_dispatch_support
def gelu(x, approximate=False):
  """Applies the Gaussian error linear unit (GELU) activation function.

  Gaussian error linear unit (GELU) computes
  `x * P(X <= x)`, where `X ~ N(0, 1)`.
  The (GELU) nonlinearity weights inputs by their value, rather than gating
  inputs by their sign as in ReLU.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.keras.activations.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
        dtype=float32)
  >>> y = tf.keras.activations.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
        dtype=float32)

  Args:
    x: Input tensor.
    approximate: A `bool`, whether to enable approximation.

  Returns:
    The gaussian error linear activation:
    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
    if `approximate` is `True` or
    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
    where `X ~ N(0, 1)`,
    if `approximate` is `False`.

  Reference:
    - [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
  """
  return nn.gelu(x, approximate)


@keras_export('keras.activations.tanh')
@dispatch.add_dispatch_support
def tanh(x):
  """Hyperbolic tangent activation function.

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.tanh(a)
  >>> b.numpy()
  array([-0.9950547, -0.7615942,  0.       ,  0.7615942,  0.9950547], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor of same shape and dtype of input `x`, with tanh activation:
    `tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
  """
  return nn.tanh(x)


@keras_export('keras.activations.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x):
  """Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`.

  Applies the sigmoid activation function. For small values (<-5),
  `sigmoid` returns a value close to zero, and for large values (>5)
  the result of the function gets close to 1.

  Sigmoid is equivalent to a 2-element Softmax, where the second element is
  assumed to be zero. The sigmoid function always returns a value between
  0 and 1.

  For example:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.sigmoid(a)
  >>> b.numpy()
  array([2.0611537e-09, 2.6894143e-01, 5.0000000e-01, 7.3105860e-01,
         1.0000000e+00], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
  """
  output = nn.sigmoid(x)
  # Cache the logits to use for crossentropy loss.
  output._keras_logits = x  # pylint: disable=protected-access
  return output
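

# Note on the `_keras_logits` attribute cached by `sigmoid` and `softmax`
# above: a rough sketch of the intended downstream use, assuming the standard
# Keras crossentropy losses (the `logits`/`probs` names are illustrative).
#
#   logits = tf.constant([[-2.0, 3.0]])
#   probs = tf.keras.activations.sigmoid(logits)
#   # Keras crossentropy losses can recover `logits` from
#   # `probs._keras_logits` and compute the loss from the logits, which is
#   # more numerically stable than recomputing log(probs) from probabilities.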


@keras_export('keras.activations.exponential')
@dispatch.add_dispatch_support
def exponential(x):
  """Exponential activation function.

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.exponential(a)
  >>> b.numpy()
  array([ 0.04978707,  0.36787945,  1.        ,  2.7182817 , 20.085537  ],
        dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor with exponential activation: `exp(x)`.
  """
  return math_ops.exp(x)


@keras_export('keras.activations.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x):
  """Hard sigmoid activation function.

  A faster, piecewise linear approximation of the sigmoid activation.

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.hard_sigmoid(a)
  >>> b.numpy()
  array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The hard sigmoid activation, defined as:

    - `if x < -2.5: return 0`
    - `if x > 2.5: return 1`
    - `if -2.5 <= x <= 2.5: return 0.2 * x + 0.5`
  """
  return K.hard_sigmoid(x)


@keras_export('keras.activations.linear')
@dispatch.add_dispatch_support
def linear(x):
  """Linear activation function (pass-through).

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.linear(a)
  >>> b.numpy()
  array([-3., -1.,  0.,  1.,  3.], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The input, unmodified.
  """
  return x


@keras_export('keras.activations.serialize')
@dispatch.add_dispatch_support
def serialize(activation):
  """Returns the string identifier of an activation function.

  Args:
    activation: Function object.

  Returns:
    String denoting the name attribute of the input function.

  For example:

  >>> tf.keras.activations.serialize(tf.keras.activations.tanh)
  'tanh'
  >>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
  'sigmoid'
  >>> tf.keras.activations.serialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: ('Cannot serialize', 'abcd')

  Raises:
    ValueError: The input function is not a valid one.
  """
  if (hasattr(activation, '__name__') and
      activation.__name__ in _TF_ACTIVATIONS_V2):
    return _TF_ACTIVATIONS_V2[activation.__name__]
  return serialize_keras_object(activation)
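

# Illustrative sketch of the `_TF_ACTIVATIONS_V2` remapping applied by
# `serialize` above (see the b/123041942 note at the top of this file):
# `tf.nn.softmax` carries the internal name 'softmax_v2', but it is serialized
# under its canonical Keras name.
#
#   tf.nn.softmax.__name__                         # -> 'softmax_v2'
#   tf.keras.activations.serialize(tf.nn.softmax)  # -> 'softmax'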


@keras_export('keras.activations.deserialize')
@dispatch.add_dispatch_support
def deserialize(name, custom_objects=None):
  """Returns activation function given a string identifier.

  Args:
    name: The name of the activation function.
    custom_objects: Optional `{function_name: function_obj}`
      dictionary listing user-provided activation functions.

  Returns:
    Corresponding activation function.

  For example:

  >>> tf.keras.activations.deserialize('linear')
  <function linear at 0x1239596a8>
  >>> tf.keras.activations.deserialize('sigmoid')
  <function sigmoid at 0x123959510>
  >>> tf.keras.activations.deserialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: `Unknown activation function` if the input string does not
      denote any defined TensorFlow activation function.
  """
  globs = globals()

  # Only replace missing activations.
  advanced_activations_globs = advanced_activations.get_globals()
  for key, val in advanced_activations_globs.items():
    if key not in globs:
      globs[key] = val

  return deserialize_keras_object(
      name,
      module_objects=globs,
      custom_objects=custom_objects,
      printable_module_name='activation function')


@keras_export('keras.activations.get')
@dispatch.add_dispatch_support
def get(identifier):
  """Returns the activation function corresponding to `identifier`.

  Args:
    identifier: Function or string.

  Returns:
    Function corresponding to the input string or input function.

  For example:

  >>> tf.keras.activations.get('softmax')
  <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(tf.keras.activations.softmax)
  <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(None)
  <function linear at 0x1239596a8>
  >>> tf.keras.activations.get(abs)
  <built-in function abs>
  >>> tf.keras.activations.get('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: Input is a string that does not denote any defined activation
      function.
    TypeError: Input is neither a string, a dict, nor a callable.
  """
  if identifier is None:
    return linear
  if isinstance(identifier, six.string_types):
    identifier = str(identifier)
    return deserialize(identifier)
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            identifier))
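

# Illustrative sketch (the `my_act` name below is hypothetical, not part of
# this module): a user-defined activation can be resolved by `deserialize`
# when supplied through `custom_objects`, mirroring how built-ins are looked
# up in this module's globals.
#
#   def my_act(x):  # hypothetical custom activation
#     return tf.nn.relu(x) - 0.1
#
#   fn = tf.keras.activations.deserialize('my_act',
#                                         custom_objects={'my_act': my_act})
#   assert fn is my_act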