1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Core Keras layers. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import sys 23import types as python_types 24import warnings 25 26import numpy as np 27 28from tensorflow.python.eager import context 29from tensorflow.python.framework import common_shapes 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.keras import activations 34from tensorflow.python.keras import backend as K 35from tensorflow.python.keras import constraints 36from tensorflow.python.keras import initializers 37from tensorflow.python.keras import regularizers 38from tensorflow.python.keras.engine.base_layer import Layer 39from tensorflow.python.keras.engine.input_spec import InputSpec 40from tensorflow.python.keras.utils import conv_utils 41from tensorflow.python.keras.utils import generic_utils 42from tensorflow.python.keras.utils import tf_utils 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import gen_math_ops 45from tensorflow.python.ops import math_ops 46from tensorflow.python.ops import nn 47from tensorflow.python.ops import nn_ops 48from 
tensorflow.python.ops import standard_ops 49from tensorflow.python.ops import variable_scope 50from tensorflow.python.util import nest 51from tensorflow.python.util.tf_export import keras_export 52 53 54@keras_export('keras.layers.Masking') 55class Masking(Layer): 56 """Masks a sequence by using a mask value to skip timesteps. 57 58 For each timestep in the input tensor (dimension #1 in the tensor), 59 if all values in the input tensor at that timestep 60 are equal to `mask_value`, then the timestep will be masked (skipped) 61 in all downstream layers (as long as they support masking). 62 63 If any downstream layer does not support masking yet receives such 64 an input mask, an exception will be raised. 65 66 Example: 67 68 Consider a Numpy data array `x` of shape `(samples, timesteps, features)`, 69 to be fed to an LSTM layer. 70 You want to mask timestep #3 and #5 because you lack data for 71 these timesteps. You can: 72 73 - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.` 74 - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer: 75 76 ```python 77 model = Sequential() 78 model.add(Masking(mask_value=0., input_shape=(timesteps, features))) 79 model.add(LSTM(32)) 80 ``` 81 """ 82 83 def __init__(self, mask_value=0., **kwargs): 84 super(Masking, self).__init__(**kwargs) 85 self.supports_masking = True 86 self.mask_value = mask_value 87 self._compute_output_and_mask_jointly = True 88 89 def compute_mask(self, inputs, mask=None): 90 return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1) 91 92 def call(self, inputs): 93 boolean_mask = K.any( 94 math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True) 95 outputs = inputs * math_ops.cast(boolean_mask, inputs.dtype) 96 # Compute the mask and outputs simultaneously. 
@keras_export('keras.layers.Masking')
class Masking(Layer):
  """Masks a sequence by using a mask value to skip timesteps.

  A timestep (dimension #1 of the input tensor) is masked when every value
  of the input at that timestep equals `mask_value`. Masked timesteps are
  skipped by all downstream layers, as long as they support masking. A
  downstream layer that does not support masking yet receives such an input
  mask raises an exception.

  Example:

  Consider a Numpy data array `x` of shape `(samples, timesteps, features)`
  that is fed to an LSTM layer. To mask timesteps #3 and #5 because data is
  missing for them, you can:

    - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.`
    - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer:

  ```python
  model = Sequential()
  model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
  model.add(LSTM(32))
  ```

  Arguments:
    mask_value: Value that marks a timestep as masked. Defaults to `0.`.
  """

  def __init__(self, mask_value=0., **kwargs):
    super(Masking, self).__init__(**kwargs)
    self.supports_masking = True
    self.mask_value = mask_value
    # `call` attaches the mask to its own output, so output and mask are
    # produced in one pass.
    self._compute_output_and_mask_jointly = True

  def compute_mask(self, inputs, mask=None):
    """Returns a boolean mask that is True where a timestep is kept."""
    differs = math_ops.not_equal(inputs, self.mask_value)
    return K.any(differs, axis=-1)

  def call(self, inputs):
    """Zeroes out masked timesteps and attaches the mask to the result."""
    differs = math_ops.not_equal(inputs, self.mask_value)
    keep = K.any(differs, axis=-1, keepdims=True)
    masked = inputs * math_ops.cast(keep, inputs.dtype)
    # Compute the mask and outputs simultaneously.
    masked._keras_mask = array_ops.squeeze(keep, axis=-1)  # pylint: disable=protected-access
    return masked

  def compute_output_shape(self, input_shape):
    """The output has exactly the shape of the input."""
    return input_shape

  def get_config(self):
    """Returns the base layer config extended with `mask_value`."""
    merged = dict(super(Masking, self).get_config())
    merged.update({'mask_value': self.mask_value})
    return merged
@keras_export('keras.layers.Dropout')
class Dropout(Layer):
  """Applies Dropout to the input.

  Dropout randomly sets a fraction `rate` of the input units to 0 at each
  update during training time, which helps prevent overfitting.

  Arguments:
    rate: Float between 0 and 1. Fraction of the input units to drop.
    noise_shape: 1D integer tensor representing the shape of the
      binary dropout mask that will be multiplied with the input.
      For instance, if your inputs have shape
      `(batch_size, timesteps, features)` and
      you want the dropout mask to be the same for all timesteps,
      you can use `noise_shape=(batch_size, 1, features)`.
    seed: A Python integer to use as random seed.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).
  """

  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
    super(Dropout, self).__init__(**kwargs)
    self.supports_masking = True
    self.rate = rate
    self.noise_shape = noise_shape
    self.seed = seed

  def _get_noise_shape(self, inputs):
    """Returns the noise shape handed to `nn.dropout`, or None if unset.

    Subclasses of `Dropout` may override this method to supply a custom
    noise shape (taking precedence over `self.noise_shape`), which allows
    noise shapes that depend on dynamically sized inputs.
    """
    configured = self.noise_shape
    if configured is None:
      return None
    return nn_ops._get_noise_shape(inputs, configured)  # pylint: disable=protected-access

  def call(self, inputs, training=None):
    """Applies dropout in training mode; returns an identity otherwise."""
    # Fall back to the backend learning phase when the caller is silent.
    phase = K.learning_phase() if training is None else training

    def _with_dropout():
      return nn.dropout(
          inputs,
          noise_shape=self._get_noise_shape(inputs),
          seed=self.seed,
          rate=self.rate)

    def _unchanged():
      return array_ops.identity(inputs)

    return tf_utils.smart_cond(phase, _with_dropout, _unchanged)

  def compute_output_shape(self, input_shape):
    """The output has exactly the shape of the input."""
    return input_shape

  def get_config(self):
    """Returns the base layer config extended with the dropout settings."""
    merged = dict(super(Dropout, self).get_config())
    merged.update({
        'rate': self.rate,
        'noise_shape': self.noise_shape,
        'seed': self.seed,
    })
    return merged
@keras_export('keras.layers.SpatialDropout1D')
class SpatialDropout1D(Dropout):
  """Spatial 1D version of Dropout.

  Performs the same function as Dropout, but drops entire 1D feature maps
  instead of individual elements. If adjacent frames within feature maps are
  strongly correlated (as is normally the case in early convolution layers)
  then regular dropout will not regularize the activations and will
  otherwise just result in an effective learning rate decrease. In this
  case, SpatialDropout1D will help promote independence between feature
  maps and should be used instead.

  Arguments:
    rate: Float between 0 and 1. Fraction of the input units to drop.

  Call arguments:
    inputs: A 3D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    3D tensor with shape:
    `(samples, timesteps, channels)`

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, **kwargs):
    super(SpatialDropout1D, self).__init__(rate, **kwargs)
    self.input_spec = InputSpec(ndim=3)

  def _get_noise_shape(self, inputs):
    """Returns `(dim0, 1, dim2)` built from the dynamic input shape."""
    dynamic_shape = array_ops.shape(inputs)
    # A size-1 middle axis makes the same mask apply along that axis.
    return (dynamic_shape[0], 1, dynamic_shape[2])
@keras_export('keras.layers.SpatialDropout2D')
class SpatialDropout2D(Dropout):
  """Spatial 2D version of Dropout.

  Performs the same function as Dropout, but drops entire 2D feature maps
  instead of individual elements. If adjacent pixels within feature maps are
  strongly correlated (as is normally the case in early convolution layers)
  then regular dropout will not regularize the activations and will
  otherwise just result in an effective learning rate decrease. In this
  case, SpatialDropout2D will help promote independence between feature
  maps and should be used instead.

  Arguments:
    rate: Float between 0 and 1. Fraction of the input units to drop.
    data_format: 'channels_first' or 'channels_last'.
      In 'channels_first' mode, the channels dimension
      (the depth) is at index 1,
      in 'channels_last' mode it is at index 3.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Call arguments:
    inputs: A 4D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    4D tensor with shape:
    `(samples, channels, rows, cols)` if data_format='channels_first'
    or 4D tensor with shape:
    `(samples, rows, cols, channels)` if data_format='channels_last'.

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, data_format=None, **kwargs):
    super(SpatialDropout2D, self).__init__(rate, **kwargs)
    chosen = K.image_data_format() if data_format is None else data_format
    if chosen not in {'channels_last', 'channels_first'}:
      raise ValueError('data_format must be in '
                       '{"channels_last", "channels_first"}')
    self.data_format = chosen
    self.input_spec = InputSpec(ndim=4)

  def _get_noise_shape(self, inputs):
    """Returns a noise shape with size 1 on both spatial axes."""
    dynamic_shape = array_ops.shape(inputs)
    layout = self.data_format
    if layout == 'channels_first':
      return (dynamic_shape[0], dynamic_shape[1], 1, 1)
    if layout == 'channels_last':
      return (dynamic_shape[0], 1, 1, dynamic_shape[3])
    # Only reachable if `data_format` was changed after construction.
    return None
@keras_export('keras.layers.SpatialDropout3D')
class SpatialDropout3D(Dropout):
  """Spatial 3D version of Dropout.

  Performs the same function as Dropout, but drops entire 3D feature maps
  instead of individual elements. If adjacent voxels within feature maps are
  strongly correlated (as is normally the case in early convolution layers)
  then regular dropout will not regularize the activations and will
  otherwise just result in an effective learning rate decrease. In this
  case, SpatialDropout3D will help promote independence between feature
  maps and should be used instead.

  Arguments:
    rate: Float between 0 and 1. Fraction of the input units to drop.
    data_format: 'channels_first' or 'channels_last'.
      In 'channels_first' mode, the channels dimension (the depth)
      is at index 1, in 'channels_last' mode it is at index 4.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Call arguments:
    inputs: A 5D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    5D tensor with shape:
    `(samples, channels, dim1, dim2, dim3)` if data_format='channels_first'
    or 5D tensor with shape:
    `(samples, dim1, dim2, dim3, channels)` if data_format='channels_last'.

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, data_format=None, **kwargs):
    super(SpatialDropout3D, self).__init__(rate, **kwargs)
    chosen = K.image_data_format() if data_format is None else data_format
    if chosen not in {'channels_last', 'channels_first'}:
      raise ValueError('data_format must be in '
                       '{"channels_last", "channels_first"}')
    self.data_format = chosen
    self.input_spec = InputSpec(ndim=5)

  def _get_noise_shape(self, inputs):
    """Returns a noise shape with size 1 on all three spatial axes."""
    dynamic_shape = array_ops.shape(inputs)
    layout = self.data_format
    if layout == 'channels_first':
      return (dynamic_shape[0], dynamic_shape[1], 1, 1, 1)
    if layout == 'channels_last':
      return (dynamic_shape[0], 1, 1, 1, dynamic_shape[4])
    # Only reachable if `data_format` was changed after construction.
    return None
@keras_export('keras.layers.Activation')
class Activation(Layer):
  """Applies an activation function to an output.

  Arguments:
    activation: Activation function, such as `tf.nn.relu`, or string name of
      built-in activation function, such as "relu".

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.
  """

  def __init__(self, activation, **kwargs):
    super(Activation, self).__init__(**kwargs)
    self.supports_masking = True
    # Resolve the name (or pass the callable through) once, up front.
    self.activation = activations.get(activation)

  def call(self, inputs):
    """Returns the configured activation applied to `inputs`."""
    activate = self.activation
    return activate(inputs)

  def compute_output_shape(self, input_shape):
    """The output has exactly the shape of the input."""
    return input_shape

  def get_config(self):
    """Returns the base layer config plus the serialized activation."""
    serialized = activations.serialize(self.activation)
    merged = dict(super(Activation, self).get_config())
    merged.update({'activation': serialized})
    return merged
@keras_export('keras.layers.Reshape')
class Reshape(Layer):
  """Reshapes an output to a certain shape.

  Arguments:
    target_shape: Target shape. Tuple of integers,
      does not include the samples dimension (batch size).

  Input shape:
    Arbitrary, although all dimensions in the input shape must be fixed.
    Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    `(batch_size,) + target_shape`

  Example:

  ```python
  # as first layer in a Sequential model
  model = Sequential()
  model.add(Reshape((3, 4), input_shape=(12,)))
  # now: model.output_shape == (None, 3, 4)
  # note: `None` is the batch dimension

  # as intermediate layer in a Sequential model
  model.add(Reshape((6, 2)))
  # now: model.output_shape == (None, 6, 2)

  # also supports shape inference using `-1` as dimension
  model.add(Reshape((-1, 2, 2)))
  # now: model.output_shape == (None, 3, 2, 2)
  ```
  """

  def __init__(self, target_shape, **kwargs):
    super(Reshape, self).__init__(**kwargs)
    self.target_shape = tuple(target_shape)

  def _fix_unknown_dimension(self, input_shape, output_shape):
    """Find and replace a missing dimension in an output shape.

    This is a near direct port of the internal Numpy function
    `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`

    Arguments:
      input_shape: Shape of array being reshaped
      output_shape: Desired shape of the array with at most
        a single -1 which indicates a dimension that should be
        derived from the input shape.

    Returns:
      The new output shape (a list) with a -1 replaced with its computed
      value.

    Raises:
      ValueError: If the total array size of the output_shape is
        different than the input_shape, or more than one unknown dimension
        is specified.
    """
    resolved = list(output_shape)
    size_error = 'total size of new array must be unchanged'

    # Multiply up the explicit extents and remember where the single
    # negative (to-be-inferred) entry sits.
    wildcard = None
    known_size = 1
    for position, extent in enumerate(resolved):
      if extent < 0:
        if wildcard is not None:
          raise ValueError('Can only specify one unknown dimension.')
        wildcard = position
      else:
        known_size *= extent

    total = np.prod(input_shape, dtype=int)
    if wildcard is None:
      # Fully specified target: element counts must match exactly.
      if total != known_size:
        raise ValueError(size_error)
      return resolved

    # The inferred extent must be a whole number.
    if known_size == 0 or total % known_size != 0:
      raise ValueError(size_error)
    resolved[wildcard] = total // known_size
    return resolved

  def compute_output_shape(self, input_shape):
    """Returns `(batch,) + target_shape`, resolving a `-1` when possible."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    batch, features = dims[0], dims[1:]
    if None in features:
      # input shape (partially) unknown? replace -1's with None's
      trailing = [None if d == -1 else d for d in self.target_shape]
    else:
      trailing = self._fix_unknown_dimension(features, self.target_shape)
    return tensor_shape.TensorShape([batch] + trailing)

  def call(self, inputs):
    """Reshapes `inputs`, keeping the dynamic size of the first axis."""
    batch = array_ops.shape(inputs)[0]
    return array_ops.reshape(inputs, (batch,) + self.target_shape)

  def get_config(self):
    """Returns the base layer config extended with `target_shape`."""
    merged = dict(super(Reshape, self).get_config())
    merged.update({'target_shape': self.target_shape})
    return merged
@keras_export('keras.layers.Permute')
class Permute(Layer):
  """Permutes the dimensions of the input according to a given pattern.

  Useful for e.g. connecting RNNs and convnets together.

  Example:

  ```python
  model = Sequential()
  model.add(Permute((2, 1), input_shape=(10, 64)))
  # now: model.output_shape == (None, 64, 10)
  # note: `None` is the batch dimension
  ```

  Arguments:
    dims: Tuple of integers. Permutation pattern, does not include the
      samples dimension. Indexing starts at 1.
      For instance, `(2, 1)` permutes the first and second dimensions
      of the input.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same as the input shape, but with the dimensions re-ordered according
    to the specified pattern.
  """

  def __init__(self, dims, **kwargs):
    super(Permute, self).__init__(**kwargs)
    self.dims = tuple(dims)
    # A valid pattern is some ordering of 1..len(dims).
    expected = list(range(1, len(dims) + 1))
    if sorted(dims) != expected:
      raise ValueError(
          'Invalid permutation `dims` for Permute Layer: %s. '
          'The set of indices in `dims` must be consecutive and start from 1.' %
          (dims,))
    self.input_spec = InputSpec(ndim=len(self.dims) + 1)

  def compute_output_shape(self, input_shape):
    """Returns the input shape with the non-batch axes reordered."""
    source = tensor_shape.TensorShape(input_shape).as_list()
    permuted = list(source)
    # Slot 0 is the batch axis and stays put; `dims` is 1-based.
    for slot, axis in enumerate(self.dims, 1):
      permuted[slot] = source[axis]
    return tensor_shape.TensorShape(permuted)

  def call(self, inputs):
    """Transposes `inputs`, leaving the first axis in place."""
    order = (0,) + self.dims
    return array_ops.transpose(inputs, perm=order)

  def get_config(self):
    """Returns the base layer config extended with `dims`."""
    merged = dict(super(Permute, self).get_config())
    merged.update({'dims': self.dims})
    return merged
@keras_export('keras.layers.Flatten')
class Flatten(Layer):
  """Flattens the input. Does not affect the batch size.

  If inputs are shaped `(batch,)` without a channel dimension, then flattening
  adds an extra channel dimension and output shapes are `(batch, 1)`.

  Arguments:
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, ..., channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Example:

  ```python
  model = Sequential()
  model.add(Convolution2D(64, 3, 3,
                          border_mode='same',
                          input_shape=(3, 32, 32)))
  # now: model.output_shape == (None, 64, 32, 32)

  model.add(Flatten())
  # now: model.output_shape == (None, 65536)
  ```
  """

  def __init__(self, data_format=None, **kwargs):
    super(Flatten, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(min_ndim=1)

  def call(self, inputs):
    """Collapses every non-batch axis of `inputs` into a single axis.

    For `channels_first` inputs of known rank greater than 1, the channel
    axis is first moved to the end so the flattened element order matches
    that of the equivalent `channels_last` tensor.
    """
    rank = K.ndim(inputs)
    if self.data_format == 'channels_first' and rank is not None and rank > 1:
      # (0, 1, 2, ..., r-1) -> (0, 2, ..., r-1, 1)
      permutation = [0] + list(range(2, rank)) + [1]
      inputs = array_ops.transpose(inputs, perm=permutation)

    # Prefer the static batch size; fall back to the dynamic one.
    batch = (tensor_shape.dimension_value(inputs.shape[0]) or
             array_ops.shape(inputs)[0])
    outputs = array_ops.reshape(inputs, (batch, -1))
    if not context.executing_eagerly():
      outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
    return outputs

  def compute_output_shape(self, input_shape):
    """Returns the `(batch, flattened)` shape for `input_shape`.

    Arguments:
      input_shape: Shape (or anything convertible to a `TensorShape`) of
        the input, including the batch axis.

    Returns:
      A `TensorShape`. A rank-0 input shape yields `[1]`. Otherwise the
      result is `[batch, prod(rest)]`, with `None` as the second entry when
      any non-batch dimension is falsy (unknown or zero).
    """
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if not input_shape:
      # Previously this value was assigned and then immediately overwritten
      # by `[input_shape[0]]`, which raised IndexError on the empty list.
      return tensor_shape.TensorShape([1])
    output_shape = [input_shape[0]]
    if all(input_shape[1:]):
      # `dtype=int` keeps the product integral even for `(batch,)` inputs,
      # where the product of an empty list would otherwise be float 1.0.
      output_shape += [np.prod(input_shape[1:], dtype=int)]
    else:
      output_shape += [None]
    return tensor_shape.TensorShape(output_shape)

  def get_config(self):
    """Returns the base layer config extended with `data_format`."""
    config = {'data_format': self.data_format}
    base_config = super(Flatten, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.RepeatVector')
class RepeatVector(Layer):
  """Repeats the input n times.

  Example:

  ```python
  model = Sequential()
  model.add(Dense(32, input_dim=32))
  # now: model.output_shape == (None, 32)
  # note: `None` is the batch dimension

  model.add(RepeatVector(3))
  # now: model.output_shape == (None, 3, 32)
  ```

  Arguments:
    n: Integer, repetition factor.

  Input shape:
    2D tensor of shape `(num_samples, features)`.

  Output shape:
    3D tensor of shape `(num_samples, n, features)`.
  """

  def __init__(self, n, **kwargs):
    super(RepeatVector, self).__init__(**kwargs)
    self.n = n
    self.input_spec = InputSpec(ndim=2)

  def compute_output_shape(self, input_shape):
    """Returns `(num_samples, n, features)` for a 2D `input_shape`."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    repeated = [dims[0], self.n, dims[1]]
    return tensor_shape.TensorShape(repeated)

  def call(self, inputs):
    """Repeats `inputs` `n` times via the backend `repeat` op."""
    copies = self.n
    return K.repeat(inputs, copies)

  def get_config(self):
    """Returns the base layer config extended with `n`."""
    merged = dict(super(RepeatVector, self).get_config())
    merged.update({'n': self.n})
    return merged
@keras_export('keras.layers.Lambda')
class Lambda(Layer):
  """Wraps arbitrary expressions as a `Layer` object.

  The `Lambda` layer exists so that arbitrary TensorFlow functions
  can be used when constructing `Sequential` and Functional API
  models. `Lambda` layers are best suited for simple operations or
  quick experimentation. For more advanced use cases, subclassing
  `keras.layers.Layer` is preferred. One reason for this is that
  when saving a Model, `Lambda` layers are saved by serializing the
  Python bytecode, whereas subclassed Layers are saved via overriding
  their `get_config` method and are thus more portable. Models that rely
  on subclassed Layers are also often easier to visualize and reason
  about.

  Examples:

  ```python
  # add a x -> x^2 layer
  model.add(Lambda(lambda x: x ** 2))
  ```
  ```python
  # add a layer that returns the concatenation
  # of the positive part of the input and
  # the opposite of the negative part

  def antirectifier(x):
    x -= K.mean(x, axis=1, keepdims=True)
    x = K.l2_normalize(x, axis=1)
    pos = K.relu(x)
    neg = K.relu(-x)
    return K.concatenate([pos, neg], axis=1)

  model.add(Lambda(antirectifier))
  ```

  Variables can be created within a `Lambda` layer. Like with
  other layers, these variables will be created only once and reused
  if the `Lambda` layer is called on new inputs. If creating more
  than one variable in a given `Lambda` instance, be sure to use
  a different name for each variable. Note that calling sublayers
  from within a `Lambda` is not supported.

  Example of variable creation:

  ```python
  def linear_transform(x):
    v1 = tf.Variable(1., name='multiplier')
    v2 = tf.Variable(0., name='bias')
    return x*v1 + v2

  linear_layer = Lambda(linear_transform)
  model.add(linear_layer)
  model.add(keras.layers.Dense(10, activation='relu'))
  model.add(linear_layer)  # Reuses existing Variables
  ```

  Note that creating two instances of `Lambda` using the same function
  will *not* share Variables between the two instances. Each instance of
  `Lambda` will create and manage its own weights.

  Arguments:
    function: The function to be evaluated. Takes input tensor as first
      argument.
    output_shape: Expected output shape from function. This argument can be
      inferred if not explicitly provided. Can be a tuple or function.
      If a tuple, it only specifies the first dimension onward;
      sample dimension is assumed either the same as the input:
      `output_shape = (input_shape[0], ) + output_shape` or, the input is
      `None` and the sample dimension is also `None`:
      `output_shape = (None, ) + output_shape`.
      If a function, it specifies the entire shape as a function of the
      input shape: `output_shape = f(input_shape)`.
    mask: Either None, a mask value returned as-is by `compute_mask`, or a
      callable invoked as `mask(inputs, mask)`.
    arguments: Optional dictionary of keyword arguments to be passed to the
      function.

  Input shape:
    Arbitrary. Use the keyword argument input_shape (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Specified by `output_shape` argument
  """

  def __init__(self, function, output_shape=None, mask=None, arguments=None,
               **kwargs):
    super(Lambda, self).__init__(**kwargs)
    self.function = function
    self.arguments = arguments if arguments else {}
    if mask is not None:
      self.supports_masking = True
    self.mask = mask
    self._output_shape = output_shape
    # Variables created by `function`, keyed by name, so repeated calls
    # reuse them instead of creating new ones.
    self._variable_dict = {}
    # These attributes are inherited from `Layer`.
    self._trainable_weights = []
    self._non_trainable_weights = []

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns the output shape for `input_shape`.

    Uses the base-class inference when no `output_shape` was given, calls
    `output_shape` when it is callable, and otherwise prepends the batch
    dimension of the input to the stored shape(s).

    Raises:
      NotImplementedError: If no `output_shape` was given and the shape
        cannot be inferred automatically.
    """
    if self._output_shape is None:
      # Make use of existing autocomputation but provide Lambda-specific
      # error message. This is always safe to run even when the outer context
      # is Graph mode because Lambda layers don't have side effects such as
      # `add_loss`.
      with context.eager_mode():
        try:
          return super(Lambda, self).compute_output_shape(input_shape)
        except NotImplementedError:
          raise NotImplementedError(
              'We could not automatically infer the shape of the Lambda\'s '
              'output. Please specify `output_shape` for this Lambda.')

    if callable(self._output_shape):
      output_shapes = self._output_shape(input_shape)
      return tf_utils.convert_shapes(output_shapes, to_tuples=False)

    # Output shapes are passed directly and don't include batch dimension.
    input_tensor_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    batch_size = nest.flatten(input_tensor_shape)[0][0] if input_shape else None

    def _add_batch(shape):
      return tensor_shape.TensorShape([batch_size] + shape.as_list())

    output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
    return nest.map_structure(_add_batch, output_shapes)

  def call(self, inputs, mask=None):
    """Evaluates the wrapped function on `inputs`.

    `mask` is forwarded only when the wrapped function accepts a `mask`
    argument.
    """
    # Work on a shallow copy: writing `mask` into `self.arguments` would
    # keep per-call state on the layer and leak it into `get_config()`.
    arguments = dict(self.arguments)
    if generic_utils.has_arg(self.function, 'mask'):
      arguments['mask'] = mask
    with variable_scope.variable_creator_scope(self._variable_creator):
      return self.function(inputs, **arguments)

  def _variable_creator(self, next_creator, **kwargs):
    """Creates a variable once per name and registers it as a weight."""
    name = kwargs['name']
    if name in self._variable_dict:
      return self._variable_dict[name]
    var = next_creator(**kwargs)
    self._variable_dict[name] = var
    if var.trainable:
      self._trainable_weights.append(var)
    else:
      self._non_trainable_weights.append(var)
    K.track_variable(var)
    return var

  def compute_mask(self, inputs, mask=None):
    """Returns `self.mask`, or its result when it is callable."""
    if callable(self.mask):
      return self.mask(inputs, mask)
    return self.mask

  def get_config(self):
    """Returns a config describing the function and output shape.

    Python lambdas are dumped with `generic_utils.func_dump`; named
    functions are stored by `__name__`, together with their module.
    """
    module = self.function.__module__
    if isinstance(self.function, python_types.LambdaType):
      function = generic_utils.func_dump(self.function)
      function_type = 'lambda'
    else:
      function = self.function.__name__
      function_type = 'function'

    output_shape_module = None
    if isinstance(self._output_shape, python_types.LambdaType):
      output_shape = generic_utils.func_dump(self._output_shape)
      output_shape_type = 'lambda'
      output_shape_module = self._output_shape.__module__
    elif callable(self._output_shape):
      output_shape = self._output_shape.__name__
      output_shape_type = 'function'
      output_shape_module = self._output_shape.__module__
    else:
      output_shape = self._output_shape
      output_shape_type = 'raw'

    config = {
        'function': function,
        'module': module,
        'function_type': function_type,
        'output_shape': output_shape,
        'output_shape_type': output_shape_type,
        'output_shape_module': output_shape_module,
        'arguments': self.arguments
    }
    base_config = super(Lambda, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a `Lambda` layer from the output of `get_config`.

    Arguments:
      config: Config dictionary; it is not modified.
      custom_objects: Optional dictionary of extra names made available
        when resolving the function and output shape.

    Returns:
      A new `Lambda` instance.

    Raises:
      TypeError: If `function_type` is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    # Work on a copy of the module namespace. Updating `globals()` in place
    # would inject foreign module attributes and custom objects into this
    # module and could shadow its own names.
    globs = globals().copy()
    module = config.pop('module', None)
    if module in sys.modules:
      globs.update(sys.modules[module].__dict__)
    elif module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(module)
                    , UserWarning)
    if custom_objects:
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode
      # NOTE(review): this executes code taken from the config; only load
      # configs from trusted sources.
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      # Format the message; passing two arguments produced a tuple repr.
      raise TypeError('Unknown function type: {}'.format(function_type))

    output_shape_module = config.pop('output_shape_module', None)
    if output_shape_module in sys.modules:
      globs.update(sys.modules[output_shape_module].__dict__)
    elif output_shape_module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(output_shape_module)
                    , UserWarning)
    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # If arguments were numpy array, they have been saved as
    # list. We need to recover the ndarray
    arguments = config.get('arguments')
    if arguments:
      # `config` is only a shallow copy, so rebuild the nested dict instead
      # of rewriting the caller's `arguments` in place.
      restored = dict(arguments)
      for key, value in arguments.items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          # Overwrite the argument with its numpy translation
          restored[key] = np.array(value['value'])
      config['arguments'] = restored

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
@keras_export('keras.layers.Dense')
class Dense(Layer):
  """Densely-connected (fully-connected) NN layer.

  The layer computes
  `output = activation(dot(input, kernel) + bias)`
  where `kernel` is a weight matrix of shape `(input_dim, units)` created in
  `build`, `bias` is a vector of shape `(units,)` created only when
  `use_bias` is `True`, and `activation` is applied element-wise (skipped
  when it is `None`).

  Note: for inputs of rank greater than 2 the product is taken with
  `tensordot` between the last axis of the input and axis 0 of `kernel`;
  rank-2 inputs use a plain matrix multiplication.

  Example:

  ```python
  # as first layer in a sequential model:
  model = Sequential()
  model.add(Dense(32, input_shape=(16,)))
  # now the model will take as input arrays of shape (*, 16)
  # and output arrays of shape (*, 32)

  # after the first layer, you don't need to specify
  # the size of the input anymore:
  model.add(Dense(32))
  ```

  Arguments:
    units: Positive integer, dimensionality of the output space
      (converted with `int()`).
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Input shape:
    N-D tensor with shape: `(batch_size, ..., input_dim)`.
    The most common situation would be
    a 2D input with shape `(batch_size, input_dim)`.

  Output shape:
    N-D tensor with shape: `(batch_size, ..., units)`.
    For instance, for a 2D input with shape `(batch_size, input_dim)`,
    the output would have shape `(batch_size, units)`.
  """

  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    # `input_dim=n` is shorthand for `input_shape=(n,)`. It is only
    # translated when no explicit `input_shape` was passed; otherwise the
    # keyword arguments are forwarded untouched.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)

    resolved_activity_regularizer = regularizers.get(activity_regularizer)
    super(Dense, self).__init__(
        activity_regularizer=resolved_activity_regularizer, **kwargs)

    self.units = int(units)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    # Identifiers (strings, configs, callables) are resolved to objects here.
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.supports_masking = True
    self.input_spec = InputSpec(min_ndim=2)

  def build(self, input_shape):
    """Creates `kernel` (and `bias` if enabled) for the given input shape.

    Raises:
      TypeError: If the layer dtype is neither floating point nor complex.
      ValueError: If the last dimension of `input_shape` is undefined.
    """
    layer_dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not layer_dtype.is_floating and not layer_dtype.is_complex:
      raise TypeError('Unable to build `Dense` layer with non-floating point '
                      'dtype %s' % (layer_dtype,))

    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')

    # From now on inputs must keep the same size on their last axis.
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})

    self.kernel = self.add_weight(
        'kernel',
        shape=[last_dim, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    if not self.use_bias:
      self.bias = None
    else:
      self.bias = self.add_weight(
          'bias',
          shape=[self.units],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    self.built = True

  def call(self, inputs):
    """Applies the kernel product, then the optional bias and activation."""
    inputs = ops.convert_to_tensor(inputs)
    rank = common_shapes.rank(inputs)
    if rank <= 2:
      # Cast the inputs to self.dtype, which is the variable dtype. We do not
      # cast if `should_cast_variables` is True, as in that case the variable
      # will be automatically casted to inputs.dtype.
      if not self._mixed_precision_policy.should_cast_variables:
        inputs = math_ops.cast(inputs, self.dtype)
      outputs = gen_math_ops.mat_mul(inputs, self.kernel)
    else:
      # Contract the last input axis with the first kernel axis; all leading
      # input axes are carried through.
      outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
      # Restore the static shape information in graph mode.
      if not context.executing_eagerly():
        static_dims = inputs.get_shape().as_list()
        outputs.set_shape(static_dims[:-1] + [self.units])

    if self.use_bias:
      outputs = nn.bias_add(outputs, self.bias)
    if self.activation is None:
      return outputs
    return self.activation(outputs)  # pylint: disable=not-callable

  def compute_output_shape(self, input_shape):
    """Returns `input_shape` with its last dimension replaced by `units`.

    Raises:
      ValueError: If the last dimension of `input_shape` is undefined.
    """
    input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    leading_dims = input_shape[:-1]
    return leading_dims.concatenate(self.units)

  def get_config(self):
    """Returns the base layer config extended with this layer's settings."""
    own_config = {
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    # Base entries come first; this layer's entries override on key clashes.
    merged = dict(super(Dense, self).get_config().items())
    merged.update(own_config)
    return merged
@keras_export('keras.layers.ActivityRegularization')
class ActivityRegularization(Layer):
  """Layer that applies an update to the cost function based on input activity.

  The layer hands an `L1L2` regularizer built from `l1` and `l2` to the base
  `Layer` as its `activity_regularizer`; it defines no computation of its own
  in this class.

  Arguments:
    l1: L1 regularization factor (positive float).
    l2: L2 regularization factor (positive float).

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.
  """

  def __init__(self, l1=0., l2=0., **kwargs):
    penalty = regularizers.L1L2(l1=l1, l2=l2)
    super(ActivityRegularization, self).__init__(
        activity_regularizer=penalty, **kwargs)
    self.supports_masking = True
    # Keep the raw factors so `get_config` can report them.
    self.l1 = l1
    self.l2 = l2

  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged."""
    return input_shape

  def get_config(self):
    """Returns the base layer config extended with `l1` and `l2`."""
    own_config = {'l1': self.l1, 'l2': self.l2}
    # Base entries come first; this layer's entries override on key clashes.
    merged = dict(super(ActivityRegularization, self).get_config().items())
    merged.update(own_config)
    return merged