1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Convolutional-recurrent layers.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.python.keras import activations
25from tensorflow.python.keras import backend as K
26from tensorflow.python.keras import constraints
27from tensorflow.python.keras import initializers
28from tensorflow.python.keras import regularizers
29from tensorflow.python.keras.engine.base_layer import Layer
30from tensorflow.python.keras.engine.input_spec import InputSpec
31from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
32from tensorflow.python.keras.layers.recurrent import RNN
33from tensorflow.python.keras.utils import conv_utils
34from tensorflow.python.keras.utils import generic_utils
35from tensorflow.python.keras.utils import tf_utils
36from tensorflow.python.ops import array_ops
37from tensorflow.python.util.tf_export import keras_export
38
39
class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Args:
    cell: A RNN cell instance. A RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
      This layer also reads the cell attributes `data_format`, `filters`,
      `kernel_size`, `strides`, `padding` and `dilation_rate` to compute
      shapes, and (after the cell is built) `kernel_shape` and
      `input_conv` to build the default initial state.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Must be False (the default). Unrolling is not supported for
      convolutional RNNs; passing True raises a `TypeError`.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states (one per
      entry of `cell.state_size`), each 4D tensor with shape:
      `(samples, state_channels, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, state_channels)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
         - If sequential model:
            `batch_input_shape=(...)` to the first layer in your model.
         - If functional model with 1 or more Input layers:
            `batch_shape=(...)` to all the first layers in your model.
            This is the expected shape of your inputs
            *including the batch size*.
            It should be a tuple of integers,
            e.g. `(32, 10, 100, 100, 32)`.
            Note that the number of rows and columns should be specified
            too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
    requires that the `cell.call` method accepts the same keyword argument
    `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    """Creates the layer.

    Raises:
      TypeError: If `unroll` is True, or if `cell` is a list/tuple of cells
        (stacking convolutional cells is not supported).
    """
    if unroll:
      raise TypeError('Unrolling isn\'t possible with '
                      'convolutional RNNs.')
    if isinstance(cell, (list, tuple)):
      # The StackedConvRNN2DCells isn't implemented yet.
      raise TypeError('It is not possible at the moment to '
                      'stack convolutional cells.')
    super(ConvRNN2D, self).__init__(cell,
                                    return_sequences,
                                    return_state,
                                    go_backwards,
                                    stateful,
                                    unroll,
                                    **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    # NOTE(review): `reset_states` later reads `self.states[0]`; this relies
    # on the base `RNN.states` property substituting per-state Nones when the
    # stored value is None — confirm against `RNN`.
    self.states = None
    self._num_constants = None

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Computes the layer's output shape (plus state shapes if requested).

    Args:
      input_shape: Shape of the 5D input, or a list whose first entry is that
        shape (the remaining entries being initial-state/constant shapes).

    Returns:
      The output shape; if `return_state` is True, a list of the output
      shape followed by one shape per entry of `cell.state_size`.
    """
    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    cell = self.cell
    if cell.data_format == 'channels_first':
      rows = input_shape[3]
      cols = input_shape[4]
    elif cell.data_format == 'channels_last':
      rows = input_shape[2]
      cols = input_shape[3]
    rows = conv_utils.conv_output_length(rows,
                                         cell.kernel_size[0],
                                         padding=cell.padding,
                                         stride=cell.strides[0],
                                         dilation=cell.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols,
                                         cell.kernel_size[1],
                                         padding=cell.padding,
                                         stride=cell.strides[1],
                                         dilation=cell.dilation_rate[1])

    if cell.data_format == 'channels_first':
      output_shape = input_shape[:2] + (cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = input_shape[:2] + (rows, cols, cell.filters)

    if not self.return_sequences:
      # Drop the time axis.
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      output_shape = [output_shape]
      # One shape per recurrent state, sized from `cell.state_size` (the same
      # source `build` and `reset_states` use), rather than assuming exactly
      # two states with `cell.filters` channels each.
      if hasattr(cell.state_size, '__len__'):
        state_channels = list(cell.state_size)
      else:
        state_channels = [cell.state_size]
      if cell.data_format == 'channels_first':
        output_shape += [(input_shape[0], dim, rows, cols)
                         for dim in state_channels]
      elif cell.data_format == 'channels_last':
        output_shape += [(input_shape[0], rows, cols, dim)
                         for dim in state_channels]
    return output_shape

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Builds the cell and sets up `input_spec` / `state_spec`.

    Raises:
      ValueError: If an `initial_state` was passed whose channel sizes do not
        match `cell.state_size`.
    """
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    if self._num_constants is not None:
      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
      constants_shape = None

    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    # Stateful layers need a fixed batch size; otherwise leave it open.
    batch_size = input_shape[0] if self.stateful else None
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
      # Per-step input: drop the time axis.
      step_input_shape = (input_shape[0],) + input_shape[2:]
      if constants_shape is not None:
        self.cell.build([step_input_shape] + constants_shape)
      else:
        self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      if self.cell.data_format == 'channels_first':
        ch_dim = 1
      elif self.cell.data_format == 'channels_last':
        ch_dim = 3
      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
        raise ValueError(
            'An initial_state was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    """Returns zero initial states, one per entry of `cell.state_size`.

    The zeros are produced by running the cell's input convolution (with a
    zero kernel whose output depth is `cell.filters`) over a zero tensor, so
    the spatial size matches what the cell's input convolution produces
    after its padding/strides. The same tensor is reused for every state.
    """
    # Zeros with the same shape as the 5D input.
    initial_state = K.zeros_like(inputs)
    # Sum over the time axis -> 4D zeros (samples + spatial + channel axes).
    initial_state = K.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape),
                                                         initial_state.dtype),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    """Runs the cell over the time axis of `inputs` via `K.rnn`.

    Returns:
      The full output sequence if `return_sequences`, otherwise the last
      output; if `return_state`, a list of that output followed by the final
      states.

    Raises:
      ValueError: If `constants` are given but `cell.call` does not accept a
        `constants` argument.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    if isinstance(mask, list):
      mask = mask[0]
    timesteps = K.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        # K.rnn appends the constants to the states; split them back out.
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
        return self.cell.call(inputs, states, constants=constants, **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      updates = [
          K.update(self_state, state)
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    """Resets (or sets) the state variables of a stateful layer.

    Args:
      states: Optional numpy array or list of arrays, one per state. If None,
        the states are set to zeros (and created on first use).

    Raises:
      AttributeError: If the layer is not stateful.
      ValueError: If the batch size / spatial size is not fully known, if the
        number of provided states is wrong, or if a provided state has the
        wrong shape.
    """
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      # Drop the time axis to get a per-step (4D) shape.
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      # The output shape with its channel axis replaced by `nb_channels`.
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [K.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          K.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        K.set_value(self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        K.set_value(state, value)
422
423
class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  At each timestep the cell combines a 2D convolution of the (optionally
  dropped-out) inputs with a 2D convolution of the (optionally dropped-out)
  previous hidden state `h_tm1`, separately for each of four parts:

    i = recurrent_activation(conv(x, W_i) + b_i + conv(h_tm1, U_i))
    f = recurrent_activation(conv(x, W_f) + b_f + conv(h_tm1, U_f))
    c = f * c_tm1 + i * activation(conv(x, W_c) + b_c + conv(h_tm1, U_c))
    o = recurrent_activation(conv(x, W_o) + b_o + conv(h_tm1, U_o))
    h = o * activation(c)

  The input convolutions use `strides`, `padding` and `dilation_rate`. The
  recurrent convolutions always use strides `(1, 1)` and `"same"` padding,
  and do not use `dilation_rate`.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of 2 integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input. Only used by the input
      convolutions.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of 2 integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function applied to the candidate cell value
      and to the new cell state when computing `h`. Defaults to hyperbolic
      tangent (`'tanh'`).
    recurrent_activation: Activation function applied to the input, forget
      and output gates. Defaults to `'hard_sigmoid'`.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, the forget-gate slice of the bias is initialized to ones
      while the other slices use `bias_initializer`; with
      `bias_initializer="zeros"` this amounts to adding 1 to the forget-gate
      bias. This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1 (values outside are clipped).
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1 (values outside are clipped).
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 4D tensor.
    states: List of state tensors corresponding to the previous timestep:
      `[h_tm1, c_tm1]` (hidden state, then cell/carry state).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.

  Returns (from `call`):
    A tuple `(h, [h, c])`: the new hidden state (also the cell output) and
    the list of new states.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    super(ConvLSTM2DCell, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                    'dilation_rate')
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    # Clip dropout rates into [0, 1].
    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    # Two states (hidden `h` and cell `c`), each with `filters` channels.
    self.state_size = (self.filters, self.filters)

  def build(self, input_shape):
    """Creates `kernel`, `recurrent_kernel` and (optionally) `bias`.

    Each weight packs the four parts (i, f, c, o) along its last axis, hence
    the `filters * 4` output depth.

    Raises:
      ValueError: If the channel dimension of `input_shape` is None.
    """

    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    # Kept on the instance; ConvRNN2D.get_initial_state reads it.
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

        # Bias layout is [i, f, c, o]; the forget slice is set to ones.
        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.get('ones')((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    """Computes one ConvLSTM step; returns `(h, [h, c])`."""
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    # Split the packed kernels (last axis) into the i, f, c, o parts.
    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    """Convolves `x` with `w` using the cell's strides/dilation, adds `b`."""
    conv_out = K.conv2d(x, w, strides=self.strides,
                        padding=padding,
                        data_format=self.data_format,
                        dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = K.bias_add(conv_out, b,
                            data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    """Convolves the hidden state with stride 1 and `"same"` padding."""
    conv_out = K.conv2d(x, w, strides=(1, 1),
                        padding='same',
                        data_format=self.data_format)
    return conv_out

  def get_config(self):
    """Returns the serializable constructor arguments of this cell."""
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
697
698
@keras_export('keras.layers.ConvLSTM2D')
class ConvLSTM2D(ConvRNN2D):
  """Convolutional LSTM.

  It is similar to an LSTM layer, but the input transformations
  and recurrent transformations are both convolutional.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, time, ..., channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, time, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      By default hyperbolic tangent activation function is applied
      (`tanh(x)`).
    recurrent_activation: Activation function to use
      for the recurrent step. Defaults to `'hard_sigmoid'`.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of
      the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence. (default False)
    return_state: Boolean. Whether to return the last state
      in addition to the output. (default False)
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or `recurrent_dropout`
      are set.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.

  Input shape:
    - If data_format='channels_first'
        5D tensor with shape:
        `(samples, time, channels, rows, cols)`
    - If data_format='channels_last'
        5D tensor with shape:
        `(samples, time, rows, cols, channels)`

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Raises:
    ValueError: in case of invalid constructor arguments.

  References:
    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
    (the current implementation does not include the feedback loop on the
    cells output).
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    # All per-step hyperparameters live on the cell; the layer only owns the
    # sequence-level options (return_sequences, stateful, ...) and the
    # activity regularizer. The cell is given the same `dtype` as the layer
    # (when one was passed in kwargs).
    cell = ConvLSTM2DCell(filters=filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          data_format=data_format,
                          dilation_rate=dilation_rate,
                          activation=activation,
                          recurrent_activation=recurrent_activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          recurrent_initializer=recurrent_initializer,
                          bias_initializer=bias_initializer,
                          unit_forget_bias=unit_forget_bias,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=recurrent_regularizer,
                          bias_regularizer=bias_regularizer,
                          kernel_constraint=kernel_constraint,
                          recurrent_constraint=recurrent_constraint,
                          bias_constraint=bias_constraint,
                          dropout=dropout,
                          recurrent_dropout=recurrent_dropout,
                          dtype=kwargs.get('dtype'))
    super(ConvLSTM2D, self).__init__(cell,
                                     return_sequences=return_sequences,
                                     return_state=return_state,
                                     go_backwards=go_backwards,
                                     stateful=stateful,
                                     **kwargs)
    # The activity regularizer is not a cell argument; it is attached to the
    # layer itself, after the base-class constructor has run.
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    """Runs the parent recurrent `call` with the arguments passed through.

    Args:
      inputs: A 5D input tensor (see the class docstring for layout).
      mask: Optional binary tensor of shape `(samples, timesteps)`.
      training: Optional Python boolean; only matters when `dropout` or
        `recurrent_dropout` is set.
      initial_state: Optional list of initial state tensors.

    Returns:
      Whatever the parent `call` returns (see "Output shape" above).
    """
    return super(ConvLSTM2D, self).call(inputs,
                                        mask=mask,
                                        training=training,
                                        initial_state=initial_state)

  # Read-only accessors: every hyperparameter below is stored on the wrapped
  # cell and simply forwarded from there.

  @property
  def filters(self):
    return self.cell.filters

  @property
  def kernel_size(self):
    return self.cell.kernel_size

  @property
  def strides(self):
    return self.cell.strides

  @property
  def padding(self):
    return self.cell.padding

  @property
  def data_format(self):
    return self.cell.data_format

  @property
  def dilation_rate(self):
    return self.cell.dilation_rate

  @property
  def activation(self):
    return self.cell.activation

  @property
  def recurrent_activation(self):
    return self.cell.recurrent_activation

  @property
  def use_bias(self):
    return self.cell.use_bias

  @property
  def kernel_initializer(self):
    return self.cell.kernel_initializer

  @property
  def recurrent_initializer(self):
    return self.cell.recurrent_initializer

  @property
  def bias_initializer(self):
    return self.cell.bias_initializer

  @property
  def unit_forget_bias(self):
    return self.cell.unit_forget_bias

  @property
  def kernel_regularizer(self):
    return self.cell.kernel_regularizer

  @property
  def recurrent_regularizer(self):
    return self.cell.recurrent_regularizer

  @property
  def bias_regularizer(self):
    return self.cell.bias_regularizer

  @property
  def kernel_constraint(self):
    return self.cell.kernel_constraint

  @property
  def recurrent_constraint(self):
    return self.cell.recurrent_constraint

  @property
  def bias_constraint(self):
    return self.cell.bias_constraint

  @property
  def dropout(self):
    return self.cell.dropout

  @property
  def recurrent_dropout(self):
    return self.cell.recurrent_dropout

  def get_config(self):
    """Returns a serializable config that `from_config` can consume.

    The config is the base-class config (without its `'cell'` entry) merged
    with the flattened cell hyperparameters and the activity regularizer, so
    that it maps one-to-one onto the `__init__` arguments.
    """
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'activity_regularizer': regularizers.serialize(
                  self.activity_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2D, self).get_config()
    # `__init__` builds its own cell from the hyperparameters above, so a
    # serialized 'cell' entry must not be fed back through `cls(**config)`
    # (it would collide with the positional `cell` passed to the base class).
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Rebuilds the layer by passing `config` straight to the constructor."""
    return cls(**config)
1014