# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Convolutional-recurrent layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import keras_export


class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Args:
    cell: An RNN cell instance. An RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
        - If sequential model:
          `batch_input_shape=(...)` to the first layer in your model.
        - If functional model with 1 or more Input layers:
          `batch_shape=(...)` to all the first layers in your model.
        This is the expected shape of your inputs
        *including the batch size*.
        It should be a tuple of integers,
        e.g. `(32, 10, 100, 100, 32)`.
        Note that the number of rows and columns should be specified
        too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
    requires that the `cell.call` method accepts the same keyword argument
    `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
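
  Example:
    An illustrative sketch of `return_state` and `initial_state`; it uses
    the concrete `ConvLSTM2D` subclass defined later in this module, and
    the shapes shown assume `padding='same'`, `strides=(1, 1)` and
    `data_format='channels_last'`:

      import numpy as np
      import tensorflow as tf

      # (samples, timesteps, rows, cols, channels)
      x = np.random.rand(2, 5, 16, 16, 3).astype(np.float32)

      layer = tf.keras.layers.ConvLSTM2D(filters=8, kernel_size=(3, 3),
                                         padding='same', return_state=True)
      output, h, c = layer(x)   # each of shape (2, 16, 16, 8)

      # The returned states can be fed to another call via `initial_state`.
      output2 = tf.keras.layers.ConvLSTM2D(
          filters=8, kernel_size=(3, 3),
          padding='same')(x, initial_state=[h, c])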
155 """ 156 157 def __init__(self, 158 cell, 159 return_sequences=False, 160 return_state=False, 161 go_backwards=False, 162 stateful=False, 163 unroll=False, 164 **kwargs): 165 if unroll: 166 raise TypeError('Unrolling isn\'t possible with ' 167 'convolutional RNNs.') 168 if isinstance(cell, (list, tuple)): 169 # The StackedConvRNN2DCells isn't implemented yet. 170 raise TypeError('It is not possible at the moment to' 171 'stack convolutional cells.') 172 super(ConvRNN2D, self).__init__(cell, 173 return_sequences, 174 return_state, 175 go_backwards, 176 stateful, 177 unroll, 178 **kwargs) 179 self.input_spec = [InputSpec(ndim=5)] 180 self.states = None 181 self._num_constants = None 182 183 @tf_utils.shape_type_conversion 184 def compute_output_shape(self, input_shape): 185 if isinstance(input_shape, list): 186 input_shape = input_shape[0] 187 188 cell = self.cell 189 if cell.data_format == 'channels_first': 190 rows = input_shape[3] 191 cols = input_shape[4] 192 elif cell.data_format == 'channels_last': 193 rows = input_shape[2] 194 cols = input_shape[3] 195 rows = conv_utils.conv_output_length(rows, 196 cell.kernel_size[0], 197 padding=cell.padding, 198 stride=cell.strides[0], 199 dilation=cell.dilation_rate[0]) 200 cols = conv_utils.conv_output_length(cols, 201 cell.kernel_size[1], 202 padding=cell.padding, 203 stride=cell.strides[1], 204 dilation=cell.dilation_rate[1]) 205 206 if cell.data_format == 'channels_first': 207 output_shape = input_shape[:2] + (cell.filters, rows, cols) 208 elif cell.data_format == 'channels_last': 209 output_shape = input_shape[:2] + (rows, cols, cell.filters) 210 211 if not self.return_sequences: 212 output_shape = output_shape[:1] + output_shape[2:] 213 214 if self.return_state: 215 output_shape = [output_shape] 216 if cell.data_format == 'channels_first': 217 output_shape += [(input_shape[0], cell.filters, rows, cols) 218 for _ in range(2)] 219 elif cell.data_format == 'channels_last': 220 output_shape += [(input_shape[0], rows, cols, cell.filters) 221 for _ in range(2)] 222 return output_shape 223 224 @tf_utils.shape_type_conversion 225 def build(self, input_shape): 226 # Note input_shape will be list of shapes of initial states and 227 # constants if these are passed in __call__. 228 if self._num_constants is not None: 229 constants_shape = input_shape[-self._num_constants:] # pylint: disable=E1130 230 else: 231 constants_shape = None 232 233 if isinstance(input_shape, list): 234 input_shape = input_shape[0] 235 236 batch_size = input_shape[0] if self.stateful else None 237 self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5]) 238 239 # allow cell (if layer) to build before we set or validate state_spec 240 if isinstance(self.cell, Layer): 241 step_input_shape = (input_shape[0],) + input_shape[2:] 242 if constants_shape is not None: 243 self.cell.build([step_input_shape] + constants_shape) 244 else: 245 self.cell.build(step_input_shape) 246 247 # set or validate state_spec 248 if hasattr(self.cell.state_size, '__len__'): 249 state_size = list(self.cell.state_size) 250 else: 251 state_size = [self.cell.state_size] 252 253 if self.state_spec is not None: 254 # initial_state was passed in call, check compatibility 255 if self.cell.data_format == 'channels_first': 256 ch_dim = 1 257 elif self.cell.data_format == 'channels_last': 258 ch_dim = 3 259 if [spec.shape[ch_dim] for spec in self.state_spec] != state_size: 260 raise ValueError( 261 'An initial_state was passed that is not compatible with ' 262 '`cell.state_size`. 
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    # (samples, timesteps, rows, cols, filters)
    initial_state = K.zeros_like(inputs)
    # (samples, rows, cols, filters)
    initial_state = K.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape),
                                                         initial_state.dtype),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    if isinstance(mask, list):
      mask = mask[0]
    timesteps = K.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
        return self.cell.call(inputs, states, constants=constants, **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      updates = [
          K.update(self_state, state)
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [K.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          K.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        K.set_value(self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        K.set_value(state, value)


class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
"linear" activation: `a(x) = x`). 452 recurrent_activation: Activation function to use 453 for the recurrent step. 454 use_bias: Boolean, whether the layer uses a bias vector. 455 kernel_initializer: Initializer for the `kernel` weights matrix, 456 used for the linear transformation of the inputs. 457 recurrent_initializer: Initializer for the `recurrent_kernel` 458 weights matrix, 459 used for the linear transformation of the recurrent state. 460 bias_initializer: Initializer for the bias vector. 461 unit_forget_bias: Boolean. 462 If True, add 1 to the bias of the forget gate at initialization. 463 Use in combination with `bias_initializer="zeros"`. 464 This is recommended in [Jozefowicz et al., 2015]( 465 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 466 kernel_regularizer: Regularizer function applied to 467 the `kernel` weights matrix. 468 recurrent_regularizer: Regularizer function applied to 469 the `recurrent_kernel` weights matrix. 470 bias_regularizer: Regularizer function applied to the bias vector. 471 kernel_constraint: Constraint function applied to 472 the `kernel` weights matrix. 473 recurrent_constraint: Constraint function applied to 474 the `recurrent_kernel` weights matrix. 475 bias_constraint: Constraint function applied to the bias vector. 476 dropout: Float between 0 and 1. 477 Fraction of the units to drop for 478 the linear transformation of the inputs. 479 recurrent_dropout: Float between 0 and 1. 480 Fraction of the units to drop for 481 the linear transformation of the recurrent state. 482 483 Call arguments: 484 inputs: A 4D tensor. 485 states: List of state tensors corresponding to the previous timestep. 486 training: Python boolean indicating whether the layer should behave in 487 training mode or in inference mode. Only relevant when `dropout` or 488 `recurrent_dropout` is used. 
489 """ 490 491 def __init__(self, 492 filters, 493 kernel_size, 494 strides=(1, 1), 495 padding='valid', 496 data_format=None, 497 dilation_rate=(1, 1), 498 activation='tanh', 499 recurrent_activation='hard_sigmoid', 500 use_bias=True, 501 kernel_initializer='glorot_uniform', 502 recurrent_initializer='orthogonal', 503 bias_initializer='zeros', 504 unit_forget_bias=True, 505 kernel_regularizer=None, 506 recurrent_regularizer=None, 507 bias_regularizer=None, 508 kernel_constraint=None, 509 recurrent_constraint=None, 510 bias_constraint=None, 511 dropout=0., 512 recurrent_dropout=0., 513 **kwargs): 514 super(ConvLSTM2DCell, self).__init__(**kwargs) 515 self.filters = filters 516 self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') 517 self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') 518 self.padding = conv_utils.normalize_padding(padding) 519 self.data_format = conv_utils.normalize_data_format(data_format) 520 self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 521 'dilation_rate') 522 self.activation = activations.get(activation) 523 self.recurrent_activation = activations.get(recurrent_activation) 524 self.use_bias = use_bias 525 526 self.kernel_initializer = initializers.get(kernel_initializer) 527 self.recurrent_initializer = initializers.get(recurrent_initializer) 528 self.bias_initializer = initializers.get(bias_initializer) 529 self.unit_forget_bias = unit_forget_bias 530 531 self.kernel_regularizer = regularizers.get(kernel_regularizer) 532 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 533 self.bias_regularizer = regularizers.get(bias_regularizer) 534 535 self.kernel_constraint = constraints.get(kernel_constraint) 536 self.recurrent_constraint = constraints.get(recurrent_constraint) 537 self.bias_constraint = constraints.get(bias_constraint) 538 539 self.dropout = min(1., max(0., dropout)) 540 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 541 self.state_size = (self.filters, self.filters) 542 543 def build(self, input_shape): 544 545 if self.data_format == 'channels_first': 546 channel_axis = 1 547 else: 548 channel_axis = -1 549 if input_shape[channel_axis] is None: 550 raise ValueError('The channel dimension of the inputs ' 551 'should be defined. 
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.get('ones')((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    # ConvLSTM gate updates: input (i), forget (f) and output (o) gates plus
    # the candidate cell state, each computed from convolutional terms.
    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    conv_out = K.conv2d(x, w, strides=self.strides,
                        padding=padding,
                        data_format=self.data_format,
                        dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = K.bias_add(conv_out, b,
                            data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    conv_out = K.conv2d(x, w, strides=(1, 1),
                        padding='same',
                        data_format=self.data_format)
    return conv_out

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.ConvLSTM2D')
class ConvLSTM2D(ConvRNN2D):
  """Convolutional LSTM.

  It is similar to an LSTM layer, but the input transformations
  and recurrent transformations are both convolutional.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, time, ..., channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, time, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      By default the hyperbolic tangent activation function is applied
      (`tanh(x)`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence. (default False)
    return_state: Boolean. Whether to return the last state
      in addition to the output. (default False)
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` are set.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.

  Input shape:
    - If data_format='channels_first'
      5D tensor with shape:
      `(samples, time, channels, rows, cols)`
    - If data_format='channels_last'
      5D tensor with shape:
      `(samples, time, rows, cols, channels)`

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Raises:
    ValueError: in case of invalid constructor arguments.

  References:
    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
      (the current implementation does not include the feedback loop on the
      cell's output).
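
  Example:
    An illustrative sketch; the shapes shown assume `padding='same'`,
    `strides=(1, 1)` and the default `channels_last` data format:

      import numpy as np
      import tensorflow as tf

      # (samples, time, rows, cols, channels)
      inputs = np.random.rand(4, 10, 32, 32, 3).astype(np.float32)

      conv_lstm = tf.keras.layers.ConvLSTM2D(filters=16, kernel_size=(3, 3),
                                             padding='same')
      outputs = conv_lstm(inputs)       # (4, 32, 32, 16): last output only

      conv_lstm_seq = tf.keras.layers.ConvLSTM2D(filters=16,
                                                 kernel_size=(3, 3),
                                                 padding='same',
                                                 return_sequences=True)
      sequence = conv_lstm_seq(inputs)  # (4, 10, 32, 32, 16): full sequence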
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    cell = ConvLSTM2DCell(filters=filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          data_format=data_format,
                          dilation_rate=dilation_rate,
                          activation=activation,
                          recurrent_activation=recurrent_activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          recurrent_initializer=recurrent_initializer,
                          bias_initializer=bias_initializer,
                          unit_forget_bias=unit_forget_bias,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=recurrent_regularizer,
                          bias_regularizer=bias_regularizer,
                          kernel_constraint=kernel_constraint,
                          recurrent_constraint=recurrent_constraint,
                          bias_constraint=bias_constraint,
                          dropout=dropout,
                          recurrent_dropout=recurrent_dropout,
                          dtype=kwargs.get('dtype'))
    super(ConvLSTM2D, self).__init__(cell,
                                     return_sequences=return_sequences,
                                     return_state=return_state,
                                     go_backwards=go_backwards,
                                     stateful=stateful,
                                     **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    return super(ConvLSTM2D, self).call(inputs,
                                        mask=mask,
                                        training=training,
                                        initial_state=initial_state)

  @property
  def filters(self):
    return self.cell.filters

  @property
  def kernel_size(self):
    return self.cell.kernel_size

  @property
  def strides(self):
    return self.cell.strides

  @property
  def padding(self):
    return self.cell.padding

  @property
  def data_format(self):
    return self.cell.data_format

  @property
  def dilation_rate(self):
    return self.cell.dilation_rate

  @property
  def activation(self):
    return self.cell.activation

  @property
  def recurrent_activation(self):
    return self.cell.recurrent_activation

  @property
  def use_bias(self):
    return self.cell.use_bias

  @property
  def kernel_initializer(self):
    return self.cell.kernel_initializer

  @property
  def recurrent_initializer(self):
    return self.cell.recurrent_initializer

  @property
  def bias_initializer(self):
    return self.cell.bias_initializer

  @property
  def unit_forget_bias(self):
    return self.cell.unit_forget_bias

  @property
  def kernel_regularizer(self):
    return self.cell.kernel_regularizer

  @property
  def recurrent_regularizer(self):
    return self.cell.recurrent_regularizer

  @property
  def bias_regularizer(self):
    return self.cell.bias_regularizer

  @property
  def kernel_constraint(self):
    return self.cell.kernel_constraint

  @property
  def recurrent_constraint(self):
    return self.cell.recurrent_constraint

  @property
  def bias_constraint(self):
    return self.cell.bias_constraint

  @property
  def dropout(self):
    return self.cell.dropout

  @property
  def recurrent_dropout(self):
    return self.cell.recurrent_dropout

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'activity_regularizer': regularizers.serialize(
                  self.activity_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2D, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)