1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Recurrent layers and their base classes.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import warnings
24
25import numpy as np
26
27from tensorflow.python.distribute import distribution_strategy_context as ds_context
28from tensorflow.python.eager import context
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.keras import activations
32from tensorflow.python.keras import backend as K
33from tensorflow.python.keras import constraints
34from tensorflow.python.keras import initializers
35from tensorflow.python.keras import regularizers
36from tensorflow.python.keras.engine.base_layer import Layer
37from tensorflow.python.keras.engine.input_spec import InputSpec
38from tensorflow.python.keras.saving.saved_model import layer_serialization
39from tensorflow.python.keras.utils import control_flow_util
40from tensorflow.python.keras.utils import generic_utils
41from tensorflow.python.keras.utils import tf_utils
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import state_ops
46from tensorflow.python.platform import tf_logging as logging
47from tensorflow.python.training.tracking import base as trackable
48from tensorflow.python.training.tracking import data_structures
49from tensorflow.python.util import nest
50from tensorflow.python.util.tf_export import keras_export
51from tensorflow.tools.docs import doc_controls
52
53
54RECURRENT_DROPOUT_WARNING_MSG = (
55    'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
56    'Using `implementation=1`.')
57
58
59@keras_export('keras.layers.StackedRNNCells')
60class StackedRNNCells(Layer):
61  """Wrapper allowing a stack of RNN cells to behave as a single cell.
62
63  Used to implement efficient stacked RNNs.
64
65  Args:
66    cells: List of RNN cell instances.
67
68  Examples:
69
70  ```python
71  batch_size = 3
72  sentence_max_length = 5
73  n_features = 2
74  new_shape = (batch_size, sentence_max_length, n_features)
  x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32)
76
77  rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
78  stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
79  lstm_layer = tf.keras.layers.RNN(stacked_lstm)
80
81  result = lstm_layer(x)
82  ```
83  """
84
85  def __init__(self, cells, **kwargs):
86    for cell in cells:
      if 'call' not in dir(cell):
        raise ValueError('All cells must have a `call` method. '
                         'Received cells: {}'.format(cells))
      if 'state_size' not in dir(cell):
        raise ValueError('All cells must have a '
                         '`state_size` attribute. '
                         'Received cells: {}'.format(cells))
94    self.cells = cells
    # reverse_state_order determines whether the state size will be in the
    # reverse order of the cells' states. Users might want to set this to True
    # to keep the existing behavior. This is only useful when using
    # RNN(return_state=True), since the states are returned in the same order
    # as state_size.
99    self.reverse_state_order = kwargs.pop('reverse_state_order', False)
100    if self.reverse_state_order:
101      logging.warning('reverse_state_order=True in StackedRNNCells will soon '
102                      'be deprecated. Please update the code to work with the '
103                      'natural order of states if you rely on the RNN states, '
104                      'eg RNN(return_state=True).')
105    super(StackedRNNCells, self).__init__(**kwargs)
106
107  @property
108  def state_size(self):
109    return tuple(c.state_size for c in
110                 (self.cells[::-1] if self.reverse_state_order else self.cells))
111
112  @property
113  def output_size(self):
114    if getattr(self.cells[-1], 'output_size', None) is not None:
115      return self.cells[-1].output_size
116    elif _is_multiple_state(self.cells[-1].state_size):
117      return self.cells[-1].state_size[0]
118    else:
119      return self.cells[-1].state_size
120
121  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
122    initial_states = []
123    for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
124      get_initial_state_fn = getattr(cell, 'get_initial_state', None)
125      if get_initial_state_fn:
126        initial_states.append(get_initial_state_fn(
127            inputs=inputs, batch_size=batch_size, dtype=dtype))
128      else:
129        initial_states.append(_generate_zero_filled_state_for_cell(
130            cell, inputs, batch_size, dtype))
131
132    return tuple(initial_states)
133
134  def call(self, inputs, states, constants=None, training=None, **kwargs):
135    # Recover per-cell states.
136    state_size = (self.state_size[::-1]
137                  if self.reverse_state_order else self.state_size)
138    nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
139
140    # Call the cells in order and store the returned states.
141    new_nested_states = []
142    for cell, states in zip(self.cells, nested_states):
143      states = states if nest.is_nested(states) else [states]
      # A TF RNN cell does not wrap the state into a list when there is only
      # one state.
145      is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
146      states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
147      if generic_utils.has_arg(cell.call, 'training'):
148        kwargs['training'] = training
149      else:
150        kwargs.pop('training', None)
      # Use the __call__ function for callable objects, e.g. layers, so that
      # the ops get the proper name scopes, etc.
153      cell_call_fn = cell.__call__ if callable(cell) else cell.call
154      if generic_utils.has_arg(cell.call, 'constants'):
155        inputs, states = cell_call_fn(inputs, states,
156                                      constants=constants, **kwargs)
157      else:
158        inputs, states = cell_call_fn(inputs, states, **kwargs)
159      new_nested_states.append(states)
160
161    return inputs, nest.pack_sequence_as(state_size,
162                                         nest.flatten(new_nested_states))
163
164  @tf_utils.shape_type_conversion
165  def build(self, input_shape):
166    if isinstance(input_shape, list):
167      input_shape = input_shape[0]
168    for cell in self.cells:
169      if isinstance(cell, Layer) and not cell.built:
170        with K.name_scope(cell.name):
171          cell.build(input_shape)
172          cell.built = True
173      if getattr(cell, 'output_size', None) is not None:
174        output_dim = cell.output_size
175      elif _is_multiple_state(cell.state_size):
176        output_dim = cell.state_size[0]
177      else:
178        output_dim = cell.state_size
179      input_shape = tuple([input_shape[0]] +
180                          tensor_shape.TensorShape(output_dim).as_list())
181    self.built = True
182
183  def get_config(self):
184    cells = []
185    for cell in self.cells:
186      cells.append(generic_utils.serialize_keras_object(cell))
187    config = {'cells': cells}
188    base_config = super(StackedRNNCells, self).get_config()
189    return dict(list(base_config.items()) + list(config.items()))
190
191  @classmethod
192  def from_config(cls, config, custom_objects=None):
193    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
194    cells = []
195    for cell_config in config.pop('cells'):
196      cells.append(
197          deserialize_layer(cell_config, custom_objects=custom_objects))
198    return cls(cells, **config)
199
200
201@keras_export('keras.layers.RNN')
202class RNN(Layer):
203  """Base class for recurrent layers.
204
205  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
206  for details about the usage of RNN API.
207
208  Args:
    cell: An RNN cell instance or a list of RNN cell instances.
      An RNN cell is a class that has:
211      - A `call(input_at_t, states_at_t)` method, returning
212        `(output_at_t, states_at_t_plus_1)`. The call method of the
213        cell can also take the optional argument `constants`, see
214        section "Note on passing external constants" below.
215      - A `state_size` attribute. This can be a single integer
216        (single state) in which case it is the size of the recurrent
217        state. This can also be a list/tuple of integers (one size per state).
218        The `state_size` can also be TensorShape or tuple/list of
219        TensorShape, to represent high dimension state.
      - An `output_size` attribute. This can be a single integer or a
        TensorShape, which represents the shape of the output. For backward
        compatibility, if this attribute is not available for the
        cell, the value will be inferred from the first element of the
        `state_size`.
225      - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
226        method that creates a tensor meant to be fed to `call()` as the
227        initial state, if the user didn't specify any initial state via other
228        means. The returned initial state should have a shape of
229        [batch_size, cell.state_size]. The cell might choose to create a
230        tensor full of zeros, or full of other values based on the cell's
231        implementation.
        `inputs` is the input tensor to the RNN layer, which should
        contain the batch size as its shape[0], and also the dtype. Note that
        the shape[0] might be `None` during graph construction. Either
        the `inputs` or the pair of `batch_size` and `dtype` is provided.
        `batch_size` is a scalar tensor that represents the batch size
        of the inputs. `dtype` is a `tf.DType` that represents the dtype of
        the inputs.
        For backward compatibility, if this method is not implemented
        by the cell, the RNN layer will create a zero-filled tensor with
        shape [batch_size, cell.state_size].
242      In the case that `cell` is a list of RNN cell instances, the cells
243      will be stacked on top of each other in the RNN, resulting in an
244      efficient stacked RNN.
245    return_sequences: Boolean (default `False`). Whether to return the last
246      output in the output sequence, or the full sequence.
247    return_state: Boolean (default `False`). Whether to return the last state
248      in addition to the output.
249    go_backwards: Boolean (default `False`).
250      If True, process the input sequence backwards and return the
251      reversed sequence.
252    stateful: Boolean (default `False`). If True, the last state
253      for each sample at index i in a batch will be used as initial
254      state for the sample of index i in the following batch.
255    unroll: Boolean (default `False`).
256      If True, the network will be unrolled, else a symbolic loop will be used.
      Unrolling can speed up an RNN, although it tends to be more
258      memory-intensive. Unrolling is only suitable for short sequences.
259    time_major: The shape format of the `inputs` and `outputs` tensors.
260      If True, the inputs and outputs will be in shape
261      `(timesteps, batch, ...)`, whereas in the False case, it will be
262      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
263      efficient because it avoids transposes at the beginning and end of the
264      RNN calculation. However, most TensorFlow data is batch-major, so by
265      default this function accepts input and emits output in batch-major
266      form.
267    zero_output_for_mask: Boolean (default `False`).
      Whether the output should use zeros for the masked timesteps. Note that
      this field is only used when `return_sequences` is True and a mask is
      provided. It can be useful if you want to reuse the raw output sequence
      of the RNN without interference from the masked timesteps, e.g. when
      merging bidirectional RNNs.
273
274  Call arguments:
275    inputs: Input tensor.
276    mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
277      a given timestep should be masked. An individual `True` entry indicates
278      that the corresponding timestep should be utilized, while a `False`
279      entry indicates that the corresponding timestep should be ignored.
280    training: Python boolean indicating whether the layer should behave in
281      training mode or in inference mode. This argument is passed to the cell
282      when calling it. This is for use with cells that use dropout.
283    initial_state: List of initial state tensors to be passed to the first
284      call of the cell.
285    constants: List of constant tensors to be passed to the cell at each
286      timestep.
287
288  Input shape:
289    N-D tensor with shape `[batch_size, timesteps, ...]` or
290    `[timesteps, batch_size, ...]` when time_major is True.
291
292  Output shape:
293    - If `return_state`: a list of tensors. The first tensor is
294      the output. The remaining tensors are the last states,
295      each with shape `[batch_size, state_size]`, where `state_size` could
296      be a high dimension tensor shape.
297    - If `return_sequences`: N-D tensor with shape
298      `[batch_size, timesteps, output_size]`, where `output_size` could
299      be a high dimension tensor shape, or
300      `[timesteps, batch_size, output_size]` when `time_major` is True.
301    - Else, N-D tensor with shape `[batch_size, output_size]`, where
302      `output_size` could be a high dimension tensor shape.
303
304  Masking:
305    This layer supports masking for input data with a variable number
306    of timesteps. To introduce masks to your data,
307    use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
308    set to `True`.
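
    As an illustration, a minimal sketch of masking with an `Embedding` layer
    (the vocabulary and layer sizes here are arbitrary):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=1000, output_dim=16,
                                  mask_zero=True),
        tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(8)),
    ])
    # Timesteps whose input token is 0 are skipped by the RNN.
    ```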
309
310  Note on using statefulness in RNNs:
311    You can set RNN layers to be 'stateful', which means that the states
312    computed for the samples in one batch will be reused as initial states
313    for the samples in the next batch. This assumes a one-to-one mapping
314    between samples in different successive batches.
315
316    To enable statefulness:
317      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
        `batch_input_shape=(...)` to the first layer in your model
        (if it is a sequential model), or
        `batch_shape=(...)` to all the first (Input) layers of your model
        (if it is a functional model with one or more Input layers).
        This is the expected shape of your inputs
        *including the batch size*.
        It should be a tuple of integers, e.g. `(32, 10, 100)`.
326      - Specify `shuffle=False` when calling `fit()`.
327
328    To reset the states of your model, call `.reset_states()` on either
329    a specific layer, or on your entire model.
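
    For example, a minimal stateful setup might look like the following sketch
    (the batch size and layer sizes are arbitrary):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(8),
                            stateful=True,
                            batch_input_shape=(32, 10, 4)),
    ])
    # model.fit(x, y, shuffle=False) reuses states across successive batches.
    model.layers[0].reset_states()  # Reset the recorded states to zeros.
    ```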
330
331  Note on specifying the initial state of RNNs:
332    You can specify the initial state of RNN layers symbolically by
333    calling them with the keyword argument `initial_state`. The value of
334    `initial_state` should be a tensor or list of tensors representing
335    the initial state of the RNN layer.
336
337    You can specify the initial state of RNN layers numerically by
338    calling `reset_states` with the keyword argument `states`. The value of
339    `states` should be a numpy array or list of numpy arrays representing
340    the initial state of the RNN layer.
341
342  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of the `RNN.__call__` (and `RNN.call`) method. This
345    requires that the `cell.call` method accepts the same keyword argument
346    `constants`. Such constants can be used to condition the cell
347    transformation on additional static inputs (not changing over time),
348    a.k.a. an attention mechanism.
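
    As a sketch, a hypothetical cell that consumes such constants only needs
    to accept the extra `constants` keyword argument in its `call` method.
    `ConstantsRNNCell` below is an illustrative example, not part of the API,
    and uses the same `keras` / `K` shorthand as the example further below:

    ```python
    class ConstantsRNNCell(keras.layers.Layer):

      def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(ConstantsRNNCell, self).__init__(**kwargs)

      def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.built = True

      def call(self, inputs, states, constants=None):
        output = K.dot(inputs, self.kernel)
        if constants is not None:
          # Condition on the static (time-invariant) input.
          output = output + constants[0]
        return output, [output]

    x = keras.Input((None, 5))
    c = keras.Input((32,))
    layer = RNN(ConstantsRNNCell(32))
    y = layer(x, constants=c)
    ```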
349
350  Examples:
351
352  ```python
353  # First, let's define a RNN Cell, as a layer subclass.
354
355  class MinimalRNNCell(keras.layers.Layer):
356
357      def __init__(self, units, **kwargs):
358          self.units = units
359          self.state_size = units
360          super(MinimalRNNCell, self).__init__(**kwargs)
361
362      def build(self, input_shape):
363          self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
364                                        initializer='uniform',
365                                        name='kernel')
366          self.recurrent_kernel = self.add_weight(
367              shape=(self.units, self.units),
368              initializer='uniform',
369              name='recurrent_kernel')
370          self.built = True
371
372      def call(self, inputs, states):
373          prev_output = states[0]
374          h = K.dot(inputs, self.kernel)
375          output = h + K.dot(prev_output, self.recurrent_kernel)
376          return output, [output]
377
378  # Let's use this cell in a RNN layer:
379
380  cell = MinimalRNNCell(32)
381  x = keras.Input((None, 5))
382  layer = RNN(cell)
383  y = layer(x)
384
385  # Here's how to use the cell to build a stacked RNN:
386
387  cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
388  x = keras.Input((None, 5))
389  layer = RNN(cells)
390  y = layer(x)
391  ```
392  """
393
394  def __init__(self,
395               cell,
396               return_sequences=False,
397               return_state=False,
398               go_backwards=False,
399               stateful=False,
400               unroll=False,
401               time_major=False,
402               **kwargs):
403    if isinstance(cell, (list, tuple)):
404      cell = StackedRNNCells(cell)
    if 'call' not in dir(cell):
      raise ValueError('`cell` should have a `call` method. '
                       'The RNN was passed: {}'.format(cell))
    if 'state_size' not in dir(cell):
      raise ValueError('The RNN cell should have '
                       'an attribute `state_size` '
                       '(tuple of integers, '
                       'one integer per RNN state).')
    # If True, the output for masked timesteps will be zeros, whereas in the
    # False case, the output from the previous timestep is returned for masked
    # timesteps.
    self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
416
417    if 'input_shape' not in kwargs and (
418        'input_dim' in kwargs or 'input_length' in kwargs):
419      input_shape = (kwargs.pop('input_length', None),
420                     kwargs.pop('input_dim', None))
421      kwargs['input_shape'] = input_shape
422
423    super(RNN, self).__init__(**kwargs)
424    self.cell = cell
425    self.return_sequences = return_sequences
426    self.return_state = return_state
427    self.go_backwards = go_backwards
428    self.stateful = stateful
429    self.unroll = unroll
430    self.time_major = time_major
431
432    self.supports_masking = True
    # The input shape is not known yet; the input could be a nested structure
    # of tensors, in which case the input_spec will be a matching structure of
    # specs, mirroring the structure of the input.
436    self.input_spec = None
437    self.state_spec = None
438    self._states = None
439    self.constants_spec = None
440    self._num_constants = 0
441
442    if stateful:
443      if ds_context.has_strategy():
444        raise ValueError('RNNs with stateful=True not yet supported with '
445                         'tf.distribute.Strategy.')
446
447  @property
448  def _use_input_spec_as_call_signature(self):
449    if self.unroll:
450      # When the RNN layer is unrolled, the time step shape cannot be unknown.
451      # The input spec does not define the time step (because this layer can be
452      # called with any time step value, as long as it is not None), so it
453      # cannot be used as the call function signature when saving to SavedModel.
454      return False
455    return super(RNN, self)._use_input_spec_as_call_signature
456
457  @property
458  def states(self):
459    if self._states is None:
460      state = nest.map_structure(lambda _: None, self.cell.state_size)
461      return state if nest.is_nested(self.cell.state_size) else [state]
462    return self._states
463
464  @states.setter
465  # Automatic tracking catches "self._states" which adds an extra weight and
466  # breaks HDF5 checkpoints.
467  @trackable.no_automatic_dependency_tracking
468  def states(self, states):
469    self._states = states
470
471  def compute_output_shape(self, input_shape):
472    if isinstance(input_shape, list):
473      input_shape = input_shape[0]
474    # Check whether the input shape contains any nested shapes. It could be
475    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
476    # inputs.
477    try:
478      input_shape = tensor_shape.TensorShape(input_shape)
479    except (ValueError, TypeError):
480      # A nested tensor input
481      input_shape = nest.flatten(input_shape)[0]
482
483    batch = input_shape[0]
484    time_step = input_shape[1]
485    if self.time_major:
486      batch, time_step = time_step, batch
487
488    if _is_multiple_state(self.cell.state_size):
489      state_size = self.cell.state_size
490    else:
491      state_size = [self.cell.state_size]
492
493    def _get_output_shape(flat_output_size):
494      output_dim = tensor_shape.TensorShape(flat_output_size).as_list()
495      if self.return_sequences:
496        if self.time_major:
497          output_shape = tensor_shape.TensorShape(
498              [time_step, batch] + output_dim)
499        else:
500          output_shape = tensor_shape.TensorShape(
501              [batch, time_step] + output_dim)
502      else:
503        output_shape = tensor_shape.TensorShape([batch] + output_dim)
504      return output_shape
505
506    if getattr(self.cell, 'output_size', None) is not None:
507      # cell.output_size could be nested structure.
508      output_shape = nest.flatten(nest.map_structure(
509          _get_output_shape, self.cell.output_size))
510      output_shape = output_shape[0] if len(output_shape) == 1 else output_shape
511    else:
512      # Note that state_size[0] could be a tensor_shape or int.
513      output_shape = _get_output_shape(state_size[0])
514
515    if self.return_state:
516      def _get_state_shape(flat_state):
517        state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list()
518        return tensor_shape.TensorShape(state_shape)
519      state_shape = nest.map_structure(_get_state_shape, state_size)
520      return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
521    else:
522      return output_shape
523
524  def compute_mask(self, inputs, mask):
525    # Time step masks must be the same for each input.
526    # This is because the mask for an RNN is of size [batch, time_steps, 1],
527    # and specifies which time steps should be skipped, and a time step
528    # must be skipped for all inputs.
529    # TODO(scottzhu): Should we accept multiple different masks?
530    mask = nest.flatten(mask)[0]
531    output_mask = mask if self.return_sequences else None
532    if self.return_state:
533      state_mask = [None for _ in self.states]
534      return [output_mask] + state_mask
535    else:
536      return output_mask
537
538  def build(self, input_shape):
539    if isinstance(input_shape, list):
540      input_shape = input_shape[0]
      # The input_shape here could be a nested structure.
542
    # Convert TensorShape objects to plain shapes here. The input could be a
    # single tensor or a nested structure of tensors.
545    def get_input_spec(shape):
546      """Convert input shape to InputSpec."""
547      if isinstance(shape, tensor_shape.TensorShape):
548        input_spec_shape = shape.as_list()
549      else:
550        input_spec_shape = list(shape)
551      batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
552      if not self.stateful:
553        input_spec_shape[batch_index] = None
554      input_spec_shape[time_step_index] = None
555      return InputSpec(shape=tuple(input_spec_shape))
556
557    def get_step_input_shape(shape):
558      if isinstance(shape, tensor_shape.TensorShape):
559        shape = tuple(shape.as_list())
560      # remove the timestep from the input_shape
561      return shape[1:] if self.time_major else (shape[0],) + shape[2:]
562
563    # Check whether the input shape contains any nested shapes. It could be
564    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
565    # inputs.
566    try:
567      input_shape = tensor_shape.TensorShape(input_shape)
568    except (ValueError, TypeError):
569      # A nested tensor input
570      pass
571
572    if not nest.is_nested(input_shape):
      # This indicates that there is only one input.
574      if self.input_spec is not None:
575        self.input_spec[0] = get_input_spec(input_shape)
576      else:
577        self.input_spec = [get_input_spec(input_shape)]
578      step_input_shape = get_step_input_shape(input_shape)
579    else:
580      if self.input_spec is not None:
581        self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
582      else:
583        self.input_spec = generic_utils.to_list(
584            nest.map_structure(get_input_spec, input_shape))
585      step_input_shape = nest.map_structure(get_step_input_shape, input_shape)
586
587    # allow cell (if layer) to build before we set or validate state_spec.
588    if isinstance(self.cell, Layer) and not self.cell.built:
589      with K.name_scope(self.cell.name):
590        self.cell.build(step_input_shape)
591        self.cell.built = True
592
593    # set or validate state_spec
594    if _is_multiple_state(self.cell.state_size):
595      state_size = list(self.cell.state_size)
596    else:
597      state_size = [self.cell.state_size]
598
599    if self.state_spec is not None:
600      # initial_state was passed in call, check compatibility
601      self._validate_state_spec(state_size, self.state_spec)
602    else:
603      self.state_spec = [
604          InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list())
605          for dim in state_size
606      ]
607    if self.stateful:
608      self.reset_states()
609    self.built = True
610
611  @staticmethod
612  def _validate_state_spec(cell_state_sizes, init_state_specs):
613    """Validate the state spec between the initial_state and the state_size.
614
615    Args:
616      cell_state_sizes: list, the `state_size` attribute from the cell.
617      init_state_specs: list, the `state_spec` from the initial_state that is
618        passed in `call()`.
619
620    Raises:
621      ValueError: When initial state spec is not compatible with the state size.
622    """
623    validation_error = ValueError(
624        'An `initial_state` was passed that is not compatible with '
625        '`cell.state_size`. Received `state_spec`={}; '
626        'however `cell.state_size` is '
627        '{}'.format(init_state_specs, cell_state_sizes))
628    flat_cell_state_sizes = nest.flatten(cell_state_sizes)
629    flat_state_specs = nest.flatten(init_state_specs)
630
631    if len(flat_cell_state_sizes) != len(flat_state_specs):
632      raise validation_error
633    for cell_state_spec, cell_state_size in zip(flat_state_specs,
634                                                flat_cell_state_sizes):
635      if not tensor_shape.TensorShape(
636          # Ignore the first axis for init_state which is for batch
637          cell_state_spec.shape[1:]).is_compatible_with(
638              tensor_shape.TensorShape(cell_state_size)):
639        raise validation_error
640
641  @doc_controls.do_not_doc_inheritable
642  def get_initial_state(self, inputs):
643    get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
644
645    if nest.is_nested(inputs):
      # The inputs are a nested structure. Use the first element in the
      # structure to get the batch size and dtype.
648      inputs = nest.flatten(inputs)[0]
649
650    input_shape = array_ops.shape(inputs)
651    batch_size = input_shape[1] if self.time_major else input_shape[0]
652    dtype = inputs.dtype
653    if get_initial_state_fn:
654      init_state = get_initial_state_fn(
655          inputs=None, batch_size=batch_size, dtype=dtype)
656    else:
657      init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
658                                               dtype)
    # Keras RNN expects the states in a list, even if it's a single state
    # tensor.
660    if not nest.is_nested(init_state):
661      init_state = [init_state]
    # Force the state to be a list in case it is a namedtuple, e.g.
    # LSTMStateTuple.
663    return list(init_state)
664
665  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
666    inputs, initial_state, constants = _standardize_args(inputs,
667                                                         initial_state,
668                                                         constants,
669                                                         self._num_constants)
670
671    if initial_state is None and constants is None:
672      return super(RNN, self).__call__(inputs, **kwargs)
673
674    # If any of `initial_state` or `constants` are specified and are Keras
675    # tensors, then add them to the inputs and temporarily modify the
676    # input_spec to include them.
677
678    additional_inputs = []
679    additional_specs = []
680    if initial_state is not None:
681      additional_inputs += initial_state
682      self.state_spec = nest.map_structure(
683          lambda s: InputSpec(shape=K.int_shape(s)), initial_state)
684      additional_specs += self.state_spec
685    if constants is not None:
686      additional_inputs += constants
687      self.constants_spec = [
688          InputSpec(shape=K.int_shape(constant)) for constant in constants
689      ]
690      self._num_constants = len(constants)
691      additional_specs += self.constants_spec
692    # additional_inputs can be empty if initial_state or constants are provided
693    # but empty (e.g. the cell is stateless).
694    flat_additional_inputs = nest.flatten(additional_inputs)
695    is_keras_tensor = K.is_keras_tensor(
696        flat_additional_inputs[0]) if flat_additional_inputs else True
697    for tensor in flat_additional_inputs:
698      if K.is_keras_tensor(tensor) != is_keras_tensor:
699        raise ValueError('The initial state or constants of an RNN'
700                         ' layer cannot be specified with a mix of'
701                         ' Keras tensors and non-Keras tensors'
702                         ' (a "Keras tensor" is a tensor that was'
703                         ' returned by a Keras layer, or by `Input`)')
704
705    if is_keras_tensor:
706      # Compute the full input spec, including state and constants
707      full_input = [inputs] + additional_inputs
708      if self.built:
709        # Keep the input_spec since it has been populated in build() method.
710        full_input_spec = self.input_spec + additional_specs
711      else:
712        # The original input_spec is None since there could be a nested tensor
713        # input. Update the input_spec to match the inputs.
714        full_input_spec = generic_utils.to_list(
715            nest.map_structure(lambda _: None, inputs)) + additional_specs
716      # Perform the call with temporarily replaced input_spec
717      self.input_spec = full_input_spec
718      output = super(RNN, self).__call__(full_input, **kwargs)
      # Remove the additional_specs from the input spec and keep the rest. It
      # is important to keep the rest since the input spec was populated by
      # build(), and will be reused when stateful=True.
722      self.input_spec = self.input_spec[:-len(additional_specs)]
723      return output
724    else:
725      if initial_state is not None:
726        kwargs['initial_state'] = initial_state
727      if constants is not None:
728        kwargs['constants'] = constants
729      return super(RNN, self).__call__(inputs, **kwargs)
730
731  def call(self,
732           inputs,
733           mask=None,
734           training=None,
735           initial_state=None,
736           constants=None):
737    # The input should be dense, padded with zeros. If a ragged input is fed
738    # into the layer, it is padded and the row lengths are used for masking.
739    inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
740    is_ragged_input = (row_lengths is not None)
741    self._validate_args_if_ragged(is_ragged_input, mask)
742
743    inputs, initial_state, constants = self._process_inputs(
744        inputs, initial_state, constants)
745
746    self._maybe_reset_cell_dropout_mask(self.cell)
747    if isinstance(self.cell, StackedRNNCells):
748      for cell in self.cell.cells:
749        self._maybe_reset_cell_dropout_mask(cell)
750
751    if mask is not None:
752      # Time step masks must be the same for each input.
753      # TODO(scottzhu): Should we accept multiple different masks?
754      mask = nest.flatten(mask)[0]
755
756    if nest.is_nested(inputs):
757      # In the case of nested input, use the first element for shape check.
758      input_shape = K.int_shape(nest.flatten(inputs)[0])
759    else:
760      input_shape = K.int_shape(inputs)
761    timesteps = input_shape[0] if self.time_major else input_shape[1]
762    if self.unroll and timesteps is None:
763      raise ValueError('Cannot unroll a RNN if the '
764                       'time dimension is undefined. \n'
765                       '- If using a Sequential model, '
766                       'specify the time dimension by passing '
767                       'an `input_shape` or `batch_input_shape` '
768                       'argument to your first layer. If your '
769                       'first layer is an Embedding, you can '
770                       'also use the `input_length` argument.\n'
771                       '- If using the functional API, specify '
772                       'the time dimension by passing a `shape` '
773                       'or `batch_shape` argument to your Input layer.')
774
775    kwargs = {}
776    if generic_utils.has_arg(self.cell.call, 'training'):
777      kwargs['training'] = training
778
    # TF RNN cells expect a single tensor as state instead of a list-wrapped
    # tensor.
    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
    # Use the __call__ function for callable objects, e.g. layers, so that the
    # ops get the proper name scopes, etc.
783    cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
784    if constants:
785      if not generic_utils.has_arg(self.cell.call, 'constants'):
786        raise ValueError('RNN cell does not support constants')
787
788      def step(inputs, states):
789        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
790        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
791
792        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
793        output, new_states = cell_call_fn(
794            inputs, states, constants=constants, **kwargs)
795        if not nest.is_nested(new_states):
796          new_states = [new_states]
797        return output, new_states
798    else:
799
800      def step(inputs, states):
801        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
802        output, new_states = cell_call_fn(inputs, states, **kwargs)
803        if not nest.is_nested(new_states):
804          new_states = [new_states]
805        return output, new_states
806    last_output, outputs, states = K.rnn(
807        step,
808        inputs,
809        initial_state,
810        constants=constants,
811        go_backwards=self.go_backwards,
812        mask=mask,
813        unroll=self.unroll,
814        input_length=row_lengths if row_lengths is not None else timesteps,
815        time_major=self.time_major,
816        zero_output_for_mask=self.zero_output_for_mask)
817
818    if self.stateful:
819      updates = [
820          state_ops.assign(self_state, state) for self_state, state in zip(
821              nest.flatten(self.states), nest.flatten(states))
822      ]
823      self.add_update(updates)
824
825    if self.return_sequences:
826      output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
827    else:
828      output = last_output
829
830    if self.return_state:
831      if not isinstance(states, (list, tuple)):
832        states = [states]
833      else:
834        states = list(states)
835      return generic_utils.to_list(output) + states
836    else:
837      return output
838
839  def _process_inputs(self, inputs, initial_state, constants):
840    # input shape: `(samples, time (padded with zeros), input_dim)`
841    # note that the .build() method of subclasses MUST define
842    # self.input_spec and self.state_spec with complete input shapes.
843    if (isinstance(inputs, collections.abc.Sequence)
844        and not isinstance(inputs, tuple)):
      # Get the initial_state from the full input spec,
      # as they could be copied to multiple GPUs.
847      if not self._num_constants:
848        initial_state = inputs[1:]
849      else:
850        initial_state = inputs[1:-self._num_constants]
851        constants = inputs[-self._num_constants:]
852      if len(initial_state) == 0:
853        initial_state = None
854      inputs = inputs[0]
855
856    if self.stateful:
857      if initial_state is not None:
        # When the layer is stateful and an initial_state is provided, check
        # whether the recorded state is the same as the default value (zeros).
        # Use the recorded state if it is not the same as the default.
861        non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s)
862                                         for s in nest.flatten(self.states)])
863        # Set strict = True to keep the original structure of the state.
864        initial_state = control_flow_ops.cond(non_zero_count > 0,
865                                              true_fn=lambda: self.states,
866                                              false_fn=lambda: initial_state,
867                                              strict=True)
868      else:
869        initial_state = self.states
870    elif initial_state is None:
871      initial_state = self.get_initial_state(inputs)
872
873    if len(initial_state) != len(self.states):
874      raise ValueError('Layer has ' + str(len(self.states)) +
875                       ' states but was passed ' + str(len(initial_state)) +
876                       ' initial states.')
877    return inputs, initial_state, constants
878
879  def _validate_args_if_ragged(self, is_ragged_input, mask):
880    if not is_ragged_input:
881      return
882
883    if mask is not None:
884      raise ValueError('The mask that was passed in was ' + str(mask) +
885                       ' and cannot be applied to RaggedTensor inputs. Please '
886                       'make sure that there is no mask passed in by upstream '
887                       'layers.')
888    if self.unroll:
889      raise ValueError('The input received contains RaggedTensors and does '
890                       'not support unrolling. Disable unrolling by passing '
891                       '`unroll=False` in the RNN Layer constructor.')
892
893  def _maybe_reset_cell_dropout_mask(self, cell):
894    if isinstance(cell, DropoutRNNCellMixin):
895      cell.reset_dropout_mask()
896      cell.reset_recurrent_dropout_mask()
897
898  def reset_states(self, states=None):
899    """Reset the recorded states for the stateful RNN layer.
900
901    Can only be used when RNN layer is constructed with `stateful` = `True`.
902    Args:
903      states: Numpy arrays that contains the value for the initial state, which
904        will be feed to cell at the first time step. When the value is None,
905        zero filled numpy array will be created based on the cell state size.
906
907    Raises:
908      AttributeError: When the RNN layer is not stateful.
909      ValueError: When the batch size of the RNN layer is unknown.
910      ValueError: When the input numpy array is not compatible with the RNN
911        layer state, either size wise or dtype wise.
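
    For example, assuming a built stateful layer with a single state of shape
    `(batch_size, units)` (the names here are illustrative):

    ```python
    rnn_layer.reset_states()  # Reset the state to zeros.
    rnn_layer.reset_states(states=np.zeros((batch_size, units)))
    ```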
912    """
913    if not self.stateful:
914      raise AttributeError('Layer must be stateful.')
915    spec_shape = None
916    if self.input_spec is not None:
917      spec_shape = nest.flatten(self.input_spec[0])[0].shape
918    if spec_shape is None:
      # It is possible for the spec shape to be None, e.g. when constructing
      # an RNN with a custom cell, or with standard RNN layers (LSTM/GRU)
      # where we only know the input is 3D, but not its full shape spec
      # before build().
922      batch_size = None
923    else:
924      batch_size = spec_shape[1] if self.time_major else spec_shape[0]
925    if not batch_size:
926      raise ValueError('If a RNN is stateful, it needs to know '
927                       'its batch size. Specify the batch size '
928                       'of your input tensors: \n'
929                       '- If using a Sequential model, '
930                       'specify the batch size by passing '
931                       'a `batch_input_shape` '
932                       'argument to your first layer.\n'
933                       '- If using the functional API, specify '
934                       'the batch size by passing a '
935                       '`batch_shape` argument to your Input layer.')
936    # initialize state if None
937    if nest.flatten(self.states)[0] is None:
938      if getattr(self.cell, 'get_initial_state', None):
939        flat_init_state_values = nest.flatten(self.cell.get_initial_state(
940            inputs=None, batch_size=batch_size,
941            dtype=self.dtype or K.floatx()))
942      else:
943        flat_init_state_values = nest.flatten(_generate_zero_filled_state(
944            batch_size, self.cell.state_size, self.dtype or K.floatx()))
945      flat_states_variables = nest.map_structure(
946          K.variable, flat_init_state_values)
947      self.states = nest.pack_sequence_as(self.cell.state_size,
948                                          flat_states_variables)
949      if not nest.is_nested(self.states):
950        self.states = [self.states]
951    elif states is None:
952      for state, size in zip(nest.flatten(self.states),
953                             nest.flatten(self.cell.state_size)):
954        K.set_value(state, np.zeros([batch_size] +
955                                    tensor_shape.TensorShape(size).as_list()))
956    else:
957      flat_states = nest.flatten(self.states)
958      flat_input_states = nest.flatten(states)
959      if len(flat_input_states) != len(flat_states):
960        raise ValueError('Layer ' + self.name + ' expects ' +
961                         str(len(flat_states)) + ' states, '
962                         'but it received ' + str(len(flat_input_states)) +
963                         ' state values. Input received: ' + str(states))
964      set_value_tuples = []
965      for i, (value, state) in enumerate(zip(flat_input_states,
966                                             flat_states)):
        if value.shape != state.shape:
          raise ValueError(
              'State ' + str(i) + ' is incompatible with layer ' +
              self.name + ': expected shape=' + str(state.shape) +
              ', found shape=' + str(value.shape))
972        set_value_tuples.append((state, value))
973      K.batch_set_value(set_value_tuples)
974
975  def get_config(self):
976    config = {
977        'return_sequences': self.return_sequences,
978        'return_state': self.return_state,
979        'go_backwards': self.go_backwards,
980        'stateful': self.stateful,
981        'unroll': self.unroll,
982        'time_major': self.time_major
983    }
984    if self._num_constants:
985      config['num_constants'] = self._num_constants
986    if self.zero_output_for_mask:
987      config['zero_output_for_mask'] = self.zero_output_for_mask
988
989    config['cell'] = generic_utils.serialize_keras_object(self.cell)
990    base_config = super(RNN, self).get_config()
991    return dict(list(base_config.items()) + list(config.items()))
992
993  @classmethod
994  def from_config(cls, config, custom_objects=None):
995    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
996    cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
997    num_constants = config.pop('num_constants', 0)
998    layer = cls(cell, **config)
999    layer._num_constants = num_constants
1000    return layer
1001
1002  @property
1003  def _trackable_saved_model_saver(self):
1004    return layer_serialization.RNNSavedModelSaver(self)
1005
1006
1007@keras_export('keras.layers.AbstractRNNCell')
1008class AbstractRNNCell(Layer):
1009  """Abstract object representing an RNN cell.
1010
1011  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1012  for details about the usage of RNN API.
1013
1014  This is the base class for implementing RNN cells with custom behavior.
1015
1016  Every `RNNCell` must have the properties below and implement `call` with
1017  the signature `(output, next_state) = call(input, state)`.
1018
1019  Examples:
1020
1021  ```python
1022    class MinimalRNNCell(AbstractRNNCell):
1023
1024      def __init__(self, units, **kwargs):
1025        self.units = units
1026        super(MinimalRNNCell, self).__init__(**kwargs)
1027
1028      @property
1029      def state_size(self):
1030        return self.units
1031
1032      def build(self, input_shape):
1033        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
1034                                      initializer='uniform',
1035                                      name='kernel')
1036        self.recurrent_kernel = self.add_weight(
1037            shape=(self.units, self.units),
1038            initializer='uniform',
1039            name='recurrent_kernel')
1040        self.built = True
1041
1042      def call(self, inputs, states):
1043        prev_output = states[0]
1044        h = K.dot(inputs, self.kernel)
1045        output = h + K.dot(prev_output, self.recurrent_kernel)
1046        return output, output
1047  ```
1048
1049  This definition of cell differs from the definition used in the literature.
1050  In the literature, 'cell' refers to an object with a single scalar output.
1051  This definition refers to a horizontal array of such units.
1052
1053  An RNN cell, in the most abstract setting, is anything that has
1054  a state and performs some operation that takes a matrix of inputs.
1055  This operation results in an output matrix with `self.output_size` columns.
1056  If `self.state_size` is an integer, this operation also results in a new
1057  state matrix with `self.state_size` columns.  If `self.state_size` is a
1058  (possibly nested tuple of) TensorShape object(s), then it should return a
1059  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
1061  """
1062
1063  def call(self, inputs, states):
1064    """The function that contains the logic for one RNN step calculation.
1065
1066    Args:
      inputs: the input tensor, which is a slice from the overall RNN input
        along the time dimension (usually the second dimension).
      states: the state tensor from the previous step, which has the same
        shape as `(batch, state_size)`. At timestep 0, it will be the
        initial state provided by the user, or a zero-filled tensor otherwise.
1072
1073    Returns:
1074      A tuple of two tensors:
1075        1. output tensor for the current timestep, with size `output_size`.
1076        2. state tensor for next step, which has the shape of `state_size`.
1077    """
1078    raise NotImplementedError('Abstract method')
1079
1080  @property
1081  def state_size(self):
1082    """size(s) of state(s) used by this cell.
1083
1084    It can be represented by an Integer, a TensorShape or a tuple of Integers
1085    or TensorShapes.
1086    """
1087    raise NotImplementedError('Abstract method')
1088
1089  @property
1090  def output_size(self):
1091    """Integer or TensorShape: size of outputs produced by this cell."""
1092    raise NotImplementedError('Abstract method')
1093
1094  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1095    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1096
1097
1098@doc_controls.do_not_generate_docs
1099class DropoutRNNCellMixin(object):
1100  """Object that hold dropout related fields for RNN Cell.
1101
1102  This class is not a standalone RNN cell. It suppose to be used with a RNN cell
1103  by multiple inheritance. Any cell that mix with class should have following
1104  fields:
1105    dropout: a float number within range [0, 1). The ratio that the input
1106      tensor need to dropout.
1107    recurrent_dropout: a float number within range [0, 1). The ratio that the
1108      recurrent state weights need to dropout.
1109  This object will create and cache created dropout masks, and reuse them for
1110  the incoming data, so that the same mask is used for every batch input.
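
  A minimal sketch of a custom cell using this mixin (the cell below is an
  illustrative example, with `K` denoting the Keras backend):

  ```python
  class MinimalDropoutCell(DropoutRNNCellMixin, Layer):

    def __init__(self, units, dropout=0., **kwargs):
      self.units = units
      self.state_size = units
      self.dropout = dropout
      self.recurrent_dropout = 0.
      super(MinimalDropoutCell, self).__init__(**kwargs)

    def build(self, input_shape):
      self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                    initializer='uniform', name='kernel')
      self.recurrent_kernel = self.add_weight(shape=(self.units, self.units),
                                              initializer='uniform',
                                              name='recurrent_kernel')
      self.built = True

    def call(self, inputs, states, training=None):
      prev_output = states[0]
      # The mask is cached, so every timestep of a batch sees the same mask.
      dp_mask = self.get_dropout_mask_for_cell(inputs, training)
      if dp_mask is not None:
        inputs = inputs * dp_mask
      output = K.dot(inputs, self.kernel) + K.dot(prev_output,
                                                  self.recurrent_kernel)
      return output, [output]
  ```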
1111  """
1112
1113  def __init__(self, *args, **kwargs):
1114    self._create_non_trackable_mask_cache()
1115    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
1116
1117  @trackable.no_automatic_dependency_tracking
1118  def _create_non_trackable_mask_cache(self):
1119    """Create the cache for dropout and recurrent dropout mask.
1120
1121    Note that the following two masks will be used in "graph function" mode,
    i.e. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
1123    tensors will be generated differently than in the "graph function" case,
1124    and they will be cached.
1125
1126    Also note that in graph mode, we still cache those masks only because the
1127    RNN could be created with `unroll=True`. In that case, the `cell.call()`
1128    function will be invoked multiple times, and we want to ensure same mask
1129    is used every time.
1130
    Also, the caches are created without tracking. Since they are not
    picklable by Python during deepcopy, we don't want
    `layer._obj_reference_counts_dict` to track them by default.
1134    """
1135    self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
1136    self._recurrent_dropout_mask_cache = K.ContextValueCache(
1137        self._create_recurrent_dropout_mask)
1138
1139  def reset_dropout_mask(self):
1140    """Reset the cached dropout masks if any.
1141
    It is important for the RNN layer to invoke this in its `call()` method so
    that the cached mask is cleared before calling `cell.call()`. The mask
    should be cached across all timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce unreasonable bias
    against certain indices of data within the batch.
1147    """
1148    self._dropout_mask_cache.clear()
1149
1150  def reset_recurrent_dropout_mask(self):
1151    """Reset the cached recurrent dropout masks if any.
1152
    It is important for the RNN layer to invoke this in its `call()` method so
    that the cached mask is cleared before calling `cell.call()`. The mask
    should be cached across all timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce unreasonable bias
    against certain indices of data within the batch.
1158    """
1159    self._recurrent_dropout_mask_cache.clear()
1160
1161  def _create_dropout_mask(self, inputs, training, count=1):
1162    return _generate_dropout_mask(
1163        array_ops.ones_like(inputs),
1164        self.dropout,
1165        training=training,
1166        count=count)
1167
1168  def _create_recurrent_dropout_mask(self, inputs, training, count=1):
1169    return _generate_dropout_mask(
1170        array_ops.ones_like(inputs),
1171        self.recurrent_dropout,
1172        training=training,
1173        count=count)
1174
1175  def get_dropout_mask_for_cell(self, inputs, training, count=1):
1176    """Get the dropout mask for RNN cell's input.
1177
    It will create a mask based on context if there isn't any existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: The input tensor whose shape will be used to generate the
        dropout mask.
      training: Boolean tensor, whether it's in training mode; dropout will be
        ignored in non-training mode.
      count: Int, how many dropout masks will be generated. It is useful for
        cells that have internal weights fused together.

    Returns:
      Generated or cached mask tensor(s), based on context.
1190    """
1191    if self.dropout == 0:
1192      return None
1193    init_kwargs = dict(inputs=inputs, training=training, count=count)
1194    return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
1195
1196  def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
1197    """Get the recurrent dropout mask for RNN cell.
1198
    It will create a mask based on context if there isn't any existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: The input tensor whose shape will be used to generate the
        dropout mask.
      training: Boolean tensor, whether it's in training mode; dropout will be
        ignored in non-training mode.
      count: Int, how many dropout masks will be generated. It is useful for
        cells that have internal weights fused together.

    Returns:
      Generated or cached mask tensor(s), based on context.
1211    """
1212    if self.recurrent_dropout == 0:
1213      return None
1214    init_kwargs = dict(inputs=inputs, training=training, count=count)
1215    return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
1216
1217  def __getstate__(self):
1218    # Used for deepcopy. The caching can't be pickled by python, since it will
1219    # contain tensor and graph.
1220    state = super(DropoutRNNCellMixin, self).__getstate__()
1221    state.pop('_dropout_mask_cache', None)
1222    state.pop('_recurrent_dropout_mask_cache', None)
1223    return state
1224
1225  def __setstate__(self, state):
1226    state['_dropout_mask_cache'] = K.ContextValueCache(
1227        self._create_dropout_mask)
1228    state['_recurrent_dropout_mask_cache'] = K.ContextValueCache(
1229        self._create_recurrent_dropout_mask)
1230    super(DropoutRNNCellMixin, self).__setstate__(state)
1231
1232
1233@keras_export('keras.layers.SimpleRNNCell')
1234class SimpleRNNCell(DropoutRNNCellMixin, Layer):
1235  """Cell class for SimpleRNN.
1236
1237  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1238  for details about the usage of RNN API.
1239
1240  This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.SimpleRNN` processes the whole sequence.
1242
1243  Args:
1244    units: Positive integer, dimensionality of the output space.
1245    activation: Activation function to use.
1246      Default: hyperbolic tangent (`tanh`).
1247      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
1249    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
1250    kernel_initializer: Initializer for the `kernel` weights matrix,
1251      used for the linear transformation of the inputs. Default:
1252      `glorot_uniform`.
1253    recurrent_initializer: Initializer for the `recurrent_kernel`
1254      weights matrix, used for the linear transformation of the recurrent state.
1255      Default: `orthogonal`.
1256    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1257    kernel_regularizer: Regularizer function applied to the `kernel` weights
1258      matrix. Default: `None`.
1259    recurrent_regularizer: Regularizer function applied to the
1260      `recurrent_kernel` weights matrix. Default: `None`.
1261    bias_regularizer: Regularizer function applied to the bias vector. Default:
1262      `None`.
1263    kernel_constraint: Constraint function applied to the `kernel` weights
1264      matrix. Default: `None`.
1265    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1266      weights matrix. Default: `None`.
1267    bias_constraint: Constraint function applied to the bias vector. Default:
1268      `None`.
1269    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
1270      transformation of the inputs. Default: 0.
1271    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
1272      the linear transformation of the recurrent state. Default: 0.
1273
1274  Call arguments:
    inputs: A 2D tensor, with shape `[batch, feature]`.
    states: A 2D tensor with shape `[batch, units]`, which is the state from
      the previous time step. For timestep 0, the initial state provided by
      the user will be fed to the cell.
1279    training: Python boolean indicating whether the layer should behave in
1280      training mode or in inference mode. Only relevant when `dropout` or
1281      `recurrent_dropout` is used.
1282
1283  Examples:
1284
1285  ```python
1286  inputs = np.random.random([32, 10, 8]).astype(np.float32)
1287  rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
1288
1289  output = rnn(inputs)  # The output has shape `[32, 4]`.
1290
1291  rnn = tf.keras.layers.RNN(
1292      tf.keras.layers.SimpleRNNCell(4),
1293      return_sequences=True,
1294      return_state=True)
1295
1296  # whole_sequence_output has shape `[32, 10, 4]`.
1297  # final_state has shape `[32, 4]`.
1298  whole_sequence_output, final_state = rnn(inputs)
1299  ```
1300  """
1301
1302  def __init__(self,
1303               units,
1304               activation='tanh',
1305               use_bias=True,
1306               kernel_initializer='glorot_uniform',
1307               recurrent_initializer='orthogonal',
1308               bias_initializer='zeros',
1309               kernel_regularizer=None,
1310               recurrent_regularizer=None,
1311               bias_regularizer=None,
1312               kernel_constraint=None,
1313               recurrent_constraint=None,
1314               bias_constraint=None,
1315               dropout=0.,
1316               recurrent_dropout=0.,
1317               **kwargs):
1318    # By default use cached variable under v2 mode, see b/143699808.
1319    if ops.executing_eagerly_outside_functions():
1320      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1321    else:
1322      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1323    super(SimpleRNNCell, self).__init__(**kwargs)
1324    self.units = units
1325    self.activation = activations.get(activation)
1326    self.use_bias = use_bias
1327
1328    self.kernel_initializer = initializers.get(kernel_initializer)
1329    self.recurrent_initializer = initializers.get(recurrent_initializer)
1330    self.bias_initializer = initializers.get(bias_initializer)
1331
1332    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1333    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1334    self.bias_regularizer = regularizers.get(bias_regularizer)
1335
1336    self.kernel_constraint = constraints.get(kernel_constraint)
1337    self.recurrent_constraint = constraints.get(recurrent_constraint)
1338    self.bias_constraint = constraints.get(bias_constraint)
1339
1340    self.dropout = min(1., max(0., dropout))
1341    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1342    self.state_size = self.units
1343    self.output_size = self.units
1344
1345  @tf_utils.shape_type_conversion
1346  def build(self, input_shape):
1347    default_caching_device = _caching_device(self)
1348    self.kernel = self.add_weight(
1349        shape=(input_shape[-1], self.units),
1350        name='kernel',
1351        initializer=self.kernel_initializer,
1352        regularizer=self.kernel_regularizer,
1353        constraint=self.kernel_constraint,
1354        caching_device=default_caching_device)
1355    self.recurrent_kernel = self.add_weight(
1356        shape=(self.units, self.units),
1357        name='recurrent_kernel',
1358        initializer=self.recurrent_initializer,
1359        regularizer=self.recurrent_regularizer,
1360        constraint=self.recurrent_constraint,
1361        caching_device=default_caching_device)
1362    if self.use_bias:
1363      self.bias = self.add_weight(
1364          shape=(self.units,),
1365          name='bias',
1366          initializer=self.bias_initializer,
1367          regularizer=self.bias_regularizer,
1368          constraint=self.bias_constraint,
1369          caching_device=default_caching_device)
1370    else:
1371      self.bias = None
1372    self.built = True
1373
1374  def call(self, inputs, states, training=None):
1375    prev_output = states[0] if nest.is_nested(states) else states
1376    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
1377    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1378        prev_output, training)
1379
1380    if dp_mask is not None:
1381      h = K.dot(inputs * dp_mask, self.kernel)
1382    else:
1383      h = K.dot(inputs, self.kernel)
1384    if self.bias is not None:
1385      h = K.bias_add(h, self.bias)
1386
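    # Apply recurrent dropout to the previous output if requested, then finish
    # the recurrence: output = activation(h + K.dot(prev_output,
    # recurrent_kernel)), where h already holds the projected inputs (+ bias).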
1387    if rec_dp_mask is not None:
1388      prev_output = prev_output * rec_dp_mask
1389    output = h + K.dot(prev_output, self.recurrent_kernel)
1390    if self.activation is not None:
1391      output = self.activation(output)
1392
1393    new_state = [output] if nest.is_nested(states) else output
1394    return output, new_state
1395
1396  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1397    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1398
1399  def get_config(self):
1400    config = {
1401        'units':
1402            self.units,
1403        'activation':
1404            activations.serialize(self.activation),
1405        'use_bias':
1406            self.use_bias,
1407        'kernel_initializer':
1408            initializers.serialize(self.kernel_initializer),
1409        'recurrent_initializer':
1410            initializers.serialize(self.recurrent_initializer),
1411        'bias_initializer':
1412            initializers.serialize(self.bias_initializer),
1413        'kernel_regularizer':
1414            regularizers.serialize(self.kernel_regularizer),
1415        'recurrent_regularizer':
1416            regularizers.serialize(self.recurrent_regularizer),
1417        'bias_regularizer':
1418            regularizers.serialize(self.bias_regularizer),
1419        'kernel_constraint':
1420            constraints.serialize(self.kernel_constraint),
1421        'recurrent_constraint':
1422            constraints.serialize(self.recurrent_constraint),
1423        'bias_constraint':
1424            constraints.serialize(self.bias_constraint),
1425        'dropout':
1426            self.dropout,
1427        'recurrent_dropout':
1428            self.recurrent_dropout
1429    }
1430    config.update(_config_for_enable_caching_device(self))
1431    base_config = super(SimpleRNNCell, self).get_config()
1432    return dict(list(base_config.items()) + list(config.items()))
1433
1434
1435@keras_export('keras.layers.SimpleRNN')
1436class SimpleRNN(RNN):
1437  """Fully-connected RNN where the output is to be fed back to input.
1438
1439  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1440  for details about the usage of RNN API.
1441
1442  Args:
1443    units: Positive integer, dimensionality of the output space.
1444    activation: Activation function to use.
1445      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
1448    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
1449    kernel_initializer: Initializer for the `kernel` weights matrix,
1450      used for the linear transformation of the inputs. Default:
1451      `glorot_uniform`.
1452    recurrent_initializer: Initializer for the `recurrent_kernel`
1453      weights matrix, used for the linear transformation of the recurrent state.
1454      Default: `orthogonal`.
1455    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1456    kernel_regularizer: Regularizer function applied to the `kernel` weights
1457      matrix. Default: `None`.
1458    recurrent_regularizer: Regularizer function applied to the
1459      `recurrent_kernel` weights matrix. Default: `None`.
1460    bias_regularizer: Regularizer function applied to the bias vector. Default:
1461      `None`.
1462    activity_regularizer: Regularizer function applied to the output of the
1463      layer (its "activation"). Default: `None`.
1464    kernel_constraint: Constraint function applied to the `kernel` weights
1465      matrix. Default: `None`.
1466    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1467      weights matrix.  Default: `None`.
1468    bias_constraint: Constraint function applied to the bias vector. Default:
1469      `None`.
1470    dropout: Float between 0 and 1.
1471      Fraction of the units to drop for the linear transformation of the inputs.
1472      Default: 0.
1473    recurrent_dropout: Float between 0 and 1.
1474      Fraction of the units to drop for the linear transformation of the
1475      recurrent state. Default: 0.
1476    return_sequences: Boolean. Whether to return the last output
1477      in the output sequence, or the full sequence. Default: `False`.
1478    return_state: Boolean. Whether to return the last state
1479      in addition to the output. Default: `False`
1480    go_backwards: Boolean (default False).
1481      If True, process the input sequence backwards and return the
1482      reversed sequence.
1483    stateful: Boolean (default False). If True, the last state
1484      for each sample at index i in a batch will be used as initial
1485      state for the sample of index i in the following batch.
1486    unroll: Boolean (default False).
1487      If True, the network will be unrolled,
1488      else a symbolic loop will be used.
      Unrolling can speed up an RNN,
1490      although it tends to be more memory-intensive.
1491      Unrolling is only suitable for short sequences.
1492
1493  Call arguments:
1494    inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
1495    mask: Binary tensor of shape `[batch, timesteps]` indicating whether
1496      a given timestep should be masked. An individual `True` entry indicates
1497      that the corresponding timestep should be utilized, while a `False` entry
1498      indicates that the corresponding timestep should be ignored.
1499    training: Python boolean indicating whether the layer should behave in
1500      training mode or in inference mode. This argument is passed to the cell
1501      when calling it. This is only relevant if `dropout` or
1502      `recurrent_dropout` is used.
1503    initial_state: List of initial state tensors to be passed to the first
1504      call of the cell.
1505
1506  Examples:
1507
1508  ```python
1509  inputs = np.random.random([32, 10, 8]).astype(np.float32)
1510  simple_rnn = tf.keras.layers.SimpleRNN(4)
1511
1512  output = simple_rnn(inputs)  # The output has shape `[32, 4]`.
1513
1514  simple_rnn = tf.keras.layers.SimpleRNN(
1515      4, return_sequences=True, return_state=True)
1516
1517  # whole_sequence_output has shape `[32, 10, 4]`.
1518  # final_state has shape `[32, 4]`.
1519  whole_sequence_output, final_state = simple_rnn(inputs)
1520  ```
1521  """
1522
1523  def __init__(self,
1524               units,
1525               activation='tanh',
1526               use_bias=True,
1527               kernel_initializer='glorot_uniform',
1528               recurrent_initializer='orthogonal',
1529               bias_initializer='zeros',
1530               kernel_regularizer=None,
1531               recurrent_regularizer=None,
1532               bias_regularizer=None,
1533               activity_regularizer=None,
1534               kernel_constraint=None,
1535               recurrent_constraint=None,
1536               bias_constraint=None,
1537               dropout=0.,
1538               recurrent_dropout=0.,
1539               return_sequences=False,
1540               return_state=False,
1541               go_backwards=False,
1542               stateful=False,
1543               unroll=False,
1544               **kwargs):
1545    if 'implementation' in kwargs:
1546      kwargs.pop('implementation')
1547      logging.warning('The `implementation` argument '
1548                      'in `SimpleRNN` has been deprecated. '
1549                      'Please remove it from your layer call.')
1550    if 'enable_caching_device' in kwargs:
1551      cell_kwargs = {'enable_caching_device':
1552                     kwargs.pop('enable_caching_device')}
1553    else:
1554      cell_kwargs = {}
1555    cell = SimpleRNNCell(
1556        units,
1557        activation=activation,
1558        use_bias=use_bias,
1559        kernel_initializer=kernel_initializer,
1560        recurrent_initializer=recurrent_initializer,
1561        bias_initializer=bias_initializer,
1562        kernel_regularizer=kernel_regularizer,
1563        recurrent_regularizer=recurrent_regularizer,
1564        bias_regularizer=bias_regularizer,
1565        kernel_constraint=kernel_constraint,
1566        recurrent_constraint=recurrent_constraint,
1567        bias_constraint=bias_constraint,
1568        dropout=dropout,
1569        recurrent_dropout=recurrent_dropout,
1570        dtype=kwargs.get('dtype'),
1571        trainable=kwargs.get('trainable', True),
1572        **cell_kwargs)
1573    super(SimpleRNN, self).__init__(
1574        cell,
1575        return_sequences=return_sequences,
1576        return_state=return_state,
1577        go_backwards=go_backwards,
1578        stateful=stateful,
1579        unroll=unroll,
1580        **kwargs)
1581    self.activity_regularizer = regularizers.get(activity_regularizer)
1582    self.input_spec = [InputSpec(ndim=3)]
1583
1584  def call(self, inputs, mask=None, training=None, initial_state=None):
1585    return super(SimpleRNN, self).call(
1586        inputs, mask=mask, training=training, initial_state=initial_state)
1587
1588  @property
1589  def units(self):
1590    return self.cell.units
1591
1592  @property
1593  def activation(self):
1594    return self.cell.activation
1595
1596  @property
1597  def use_bias(self):
1598    return self.cell.use_bias
1599
1600  @property
1601  def kernel_initializer(self):
1602    return self.cell.kernel_initializer
1603
1604  @property
1605  def recurrent_initializer(self):
1606    return self.cell.recurrent_initializer
1607
1608  @property
1609  def bias_initializer(self):
1610    return self.cell.bias_initializer
1611
1612  @property
1613  def kernel_regularizer(self):
1614    return self.cell.kernel_regularizer
1615
1616  @property
1617  def recurrent_regularizer(self):
1618    return self.cell.recurrent_regularizer
1619
1620  @property
1621  def bias_regularizer(self):
1622    return self.cell.bias_regularizer
1623
1624  @property
1625  def kernel_constraint(self):
1626    return self.cell.kernel_constraint
1627
1628  @property
1629  def recurrent_constraint(self):
1630    return self.cell.recurrent_constraint
1631
1632  @property
1633  def bias_constraint(self):
1634    return self.cell.bias_constraint
1635
1636  @property
1637  def dropout(self):
1638    return self.cell.dropout
1639
1640  @property
1641  def recurrent_dropout(self):
1642    return self.cell.recurrent_dropout
1643
1644  def get_config(self):
1645    config = {
1646        'units':
1647            self.units,
1648        'activation':
1649            activations.serialize(self.activation),
1650        'use_bias':
1651            self.use_bias,
1652        'kernel_initializer':
1653            initializers.serialize(self.kernel_initializer),
1654        'recurrent_initializer':
1655            initializers.serialize(self.recurrent_initializer),
1656        'bias_initializer':
1657            initializers.serialize(self.bias_initializer),
1658        'kernel_regularizer':
1659            regularizers.serialize(self.kernel_regularizer),
1660        'recurrent_regularizer':
1661            regularizers.serialize(self.recurrent_regularizer),
1662        'bias_regularizer':
1663            regularizers.serialize(self.bias_regularizer),
1664        'activity_regularizer':
1665            regularizers.serialize(self.activity_regularizer),
1666        'kernel_constraint':
1667            constraints.serialize(self.kernel_constraint),
1668        'recurrent_constraint':
1669            constraints.serialize(self.recurrent_constraint),
1670        'bias_constraint':
1671            constraints.serialize(self.bias_constraint),
1672        'dropout':
1673            self.dropout,
1674        'recurrent_dropout':
1675            self.recurrent_dropout
1676    }
1677    base_config = super(SimpleRNN, self).get_config()
1678    config.update(_config_for_enable_caching_device(self.cell))
1679    del base_config['cell']
1680    return dict(list(base_config.items()) + list(config.items()))
1681
1682  @classmethod
1683  def from_config(cls, config):
1684    if 'implementation' in config:
1685      config.pop('implementation')
1686    return cls(**config)
1687
1688
1689@keras_export(v1=['keras.layers.GRUCell'])
1690class GRUCell(DropoutRNNCellMixin, Layer):
1691  """Cell class for the GRU layer.
1692
1693  Args:
1694    units: Positive integer, dimensionality of the output space.
1695    activation: Activation function to use.
1696      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: hard sigmoid (`hard_sigmoid`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
1704    use_bias: Boolean, whether the layer uses a bias vector.
1705    kernel_initializer: Initializer for the `kernel` weights matrix,
1706      used for the linear transformation of the inputs.
1707    recurrent_initializer: Initializer for the `recurrent_kernel`
1708      weights matrix,
1709      used for the linear transformation of the recurrent state.
1710    bias_initializer: Initializer for the bias vector.
1711    kernel_regularizer: Regularizer function applied to
1712      the `kernel` weights matrix.
1713    recurrent_regularizer: Regularizer function applied to
1714      the `recurrent_kernel` weights matrix.
1715    bias_regularizer: Regularizer function applied to the bias vector.
1716    kernel_constraint: Constraint function applied to
1717      the `kernel` weights matrix.
1718    recurrent_constraint: Constraint function applied to
1719      the `recurrent_kernel` weights matrix.
1720    bias_constraint: Constraint function applied to the bias vector.
1721    dropout: Float between 0 and 1.
1722      Fraction of the units to drop for the linear transformation of the inputs.
1723    recurrent_dropout: Float between 0 and 1.
1724      Fraction of the units to drop for
1725      the linear transformation of the recurrent state.
1726    reset_after: GRU convention (whether to apply reset gate after or
1727      before matrix multiplication). False = "before" (default),
1728      True = "after" (CuDNN compatible).
1729
1730  Call arguments:
1731    inputs: A 2D tensor.
1732    states: List of state tensors corresponding to the previous timestep.
1733    training: Python boolean indicating whether the layer should behave in
1734      training mode or in inference mode. Only relevant when `dropout` or
1735      `recurrent_dropout` is used.
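
  Example (an illustrative usage sketch, following the `SimpleRNNCell` example
  earlier in this module; the input shape is arbitrary):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  rnn = tf.keras.layers.RNN(GRUCell(4))

  output = rnn(inputs)  # The output has shape `[32, 4]`.
  ```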
1736  """
1737
1738  def __init__(self,
1739               units,
1740               activation='tanh',
1741               recurrent_activation='hard_sigmoid',
1742               use_bias=True,
1743               kernel_initializer='glorot_uniform',
1744               recurrent_initializer='orthogonal',
1745               bias_initializer='zeros',
1746               kernel_regularizer=None,
1747               recurrent_regularizer=None,
1748               bias_regularizer=None,
1749               kernel_constraint=None,
1750               recurrent_constraint=None,
1751               bias_constraint=None,
1752               dropout=0.,
1753               recurrent_dropout=0.,
1754               reset_after=False,
1755               **kwargs):
1756    # By default use cached variable under v2 mode, see b/143699808.
1757    if ops.executing_eagerly_outside_functions():
1758      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1759    else:
1760      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1761    super(GRUCell, self).__init__(**kwargs)
1762    self.units = units
1763    self.activation = activations.get(activation)
1764    self.recurrent_activation = activations.get(recurrent_activation)
1765    self.use_bias = use_bias
1766
1767    self.kernel_initializer = initializers.get(kernel_initializer)
1768    self.recurrent_initializer = initializers.get(recurrent_initializer)
1769    self.bias_initializer = initializers.get(bias_initializer)
1770
1771    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1772    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1773    self.bias_regularizer = regularizers.get(bias_regularizer)
1774
1775    self.kernel_constraint = constraints.get(kernel_constraint)
1776    self.recurrent_constraint = constraints.get(recurrent_constraint)
1777    self.bias_constraint = constraints.get(bias_constraint)
1778
1779    self.dropout = min(1., max(0., dropout))
1780    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1781
1782    implementation = kwargs.pop('implementation', 1)
1783    if self.recurrent_dropout != 0 and implementation != 1:
1784      logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
1785      self.implementation = 1
1786    else:
1787      self.implementation = implementation
1788    self.reset_after = reset_after
1789    self.state_size = self.units
1790    self.output_size = self.units
1791
1792  @tf_utils.shape_type_conversion
1793  def build(self, input_shape):
1794    input_dim = input_shape[-1]
1795    default_caching_device = _caching_device(self)
1796    self.kernel = self.add_weight(
1797        shape=(input_dim, self.units * 3),
1798        name='kernel',
1799        initializer=self.kernel_initializer,
1800        regularizer=self.kernel_regularizer,
1801        constraint=self.kernel_constraint,
1802        caching_device=default_caching_device)
1803    self.recurrent_kernel = self.add_weight(
1804        shape=(self.units, self.units * 3),
1805        name='recurrent_kernel',
1806        initializer=self.recurrent_initializer,
1807        regularizer=self.recurrent_regularizer,
1808        constraint=self.recurrent_constraint,
1809        caching_device=default_caching_device)
1810
1811    if self.use_bias:
1812      if not self.reset_after:
1813        bias_shape = (3 * self.units,)
1814      else:
1815        # separate biases for input and recurrent kernels
1816        # Note: the shape is intentionally different from CuDNNGRU biases
1817        # `(2 * 3 * self.units,)`, so that we can distinguish the classes
1818        # when loading and converting saved weights.
1819        bias_shape = (2, 3 * self.units)
1820      self.bias = self.add_weight(shape=bias_shape,
1821                                  name='bias',
1822                                  initializer=self.bias_initializer,
1823                                  regularizer=self.bias_regularizer,
1824                                  constraint=self.bias_constraint,
1825                                  caching_device=default_caching_device)
1826    else:
1827      self.bias = None
1828    self.built = True
1829
1830  def call(self, inputs, states, training=None):
1831    h_tm1 = states[0] if nest.is_nested(states) else states  # previous memory
1832
1833    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
1834    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1835        h_tm1, training, count=3)
1836
1837    if self.use_bias:
1838      if not self.reset_after:
1839        input_bias, recurrent_bias = self.bias, None
1840      else:
1841        input_bias, recurrent_bias = array_ops.unstack(self.bias)
1842
1843    if self.implementation == 1:
1844      if 0. < self.dropout < 1.:
1845        inputs_z = inputs * dp_mask[0]
1846        inputs_r = inputs * dp_mask[1]
1847        inputs_h = inputs * dp_mask[2]
1848      else:
1849        inputs_z = inputs
1850        inputs_r = inputs
1851        inputs_h = inputs
1852
1853      x_z = K.dot(inputs_z, self.kernel[:, :self.units])
1854      x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
1855      x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
1856
1857      if self.use_bias:
1858        x_z = K.bias_add(x_z, input_bias[:self.units])
1859        x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2])
1860        x_h = K.bias_add(x_h, input_bias[self.units * 2:])
1861
1862      if 0. < self.recurrent_dropout < 1.:
1863        h_tm1_z = h_tm1 * rec_dp_mask[0]
1864        h_tm1_r = h_tm1 * rec_dp_mask[1]
1865        h_tm1_h = h_tm1 * rec_dp_mask[2]
1866      else:
1867        h_tm1_z = h_tm1
1868        h_tm1_r = h_tm1
1869        h_tm1_h = h_tm1
1870
1871      recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
1872      recurrent_r = K.dot(h_tm1_r,
1873                          self.recurrent_kernel[:, self.units:self.units * 2])
1874      if self.reset_after and self.use_bias:
1875        recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units])
1876        recurrent_r = K.bias_add(recurrent_r,
1877                                 recurrent_bias[self.units:self.units * 2])
1878
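      # z is the update gate and r is the reset gate.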
1879      z = self.recurrent_activation(x_z + recurrent_z)
1880      r = self.recurrent_activation(x_r + recurrent_r)
1881
1882      # reset gate applied after/before matrix multiplication
1883      if self.reset_after:
1884        recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1885        if self.use_bias:
1886          recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:])
1887        recurrent_h = r * recurrent_h
1888      else:
1889        recurrent_h = K.dot(r * h_tm1_h,
1890                            self.recurrent_kernel[:, self.units * 2:])
1891
1892      hh = self.activation(x_h + recurrent_h)
1893    else:
1894      if 0. < self.dropout < 1.:
1895        inputs = inputs * dp_mask[0]
1896
1897      # inputs projected by all gate matrices at once
1898      matrix_x = K.dot(inputs, self.kernel)
1899      if self.use_bias:
1900        # biases: bias_z_i, bias_r_i, bias_h_i
1901        matrix_x = K.bias_add(matrix_x, input_bias)
1902
1903      x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
1904
1905      if self.reset_after:
1906        # hidden state projected by all gate matrices at once
1907        matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
1908        if self.use_bias:
1909          matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
1910      else:
1911        # hidden state projected separately for update/reset and new
1912        matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
1913
1914      recurrent_z, recurrent_r, recurrent_h = array_ops.split(
1915          matrix_inner, [self.units, self.units, -1], axis=-1)
1916
1917      z = self.recurrent_activation(x_z + recurrent_z)
1918      r = self.recurrent_activation(x_r + recurrent_r)
1919
1920      if self.reset_after:
1921        recurrent_h = r * recurrent_h
1922      else:
1923        recurrent_h = K.dot(r * h_tm1,
1924                            self.recurrent_kernel[:, 2 * self.units:])
1925
1926      hh = self.activation(x_h + recurrent_h)
1927    # previous and candidate state mixed by update gate
1928    h = z * h_tm1 + (1 - z) * hh
1929    new_state = [h] if nest.is_nested(states) else h
1930    return h, new_state
1931
1932  def get_config(self):
1933    config = {
1934        'units': self.units,
1935        'activation': activations.serialize(self.activation),
1936        'recurrent_activation':
1937            activations.serialize(self.recurrent_activation),
1938        'use_bias': self.use_bias,
1939        'kernel_initializer': initializers.serialize(self.kernel_initializer),
1940        'recurrent_initializer':
1941            initializers.serialize(self.recurrent_initializer),
1942        'bias_initializer': initializers.serialize(self.bias_initializer),
1943        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
1944        'recurrent_regularizer':
1945            regularizers.serialize(self.recurrent_regularizer),
1946        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
1947        'kernel_constraint': constraints.serialize(self.kernel_constraint),
1948        'recurrent_constraint':
1949            constraints.serialize(self.recurrent_constraint),
1950        'bias_constraint': constraints.serialize(self.bias_constraint),
1951        'dropout': self.dropout,
1952        'recurrent_dropout': self.recurrent_dropout,
1953        'implementation': self.implementation,
1954        'reset_after': self.reset_after
1955    }
1956    config.update(_config_for_enable_caching_device(self))
1957    base_config = super(GRUCell, self).get_config()
1958    return dict(list(base_config.items()) + list(config.items()))
1959
1960  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1961    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1962
1963
1964@keras_export(v1=['keras.layers.GRU'])
1965class GRU(RNN):
1966  """Gated Recurrent Unit - Cho et al. 2014.
1967
  There are two variants. The default one is based on
  [1406.1078v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate
  applied to the hidden state before the matrix multiplication. The other one
  is based on the original [1406.1078v1](https://arxiv.org/abs/1406.1078v1)
  and has the order reversed.

  The second variant is compatible with CuDNNGRU (GPU-only) and allows
  inference on CPU. Thus it has separate biases for `kernel` and
  `recurrent_kernel`. To use this variant, set `reset_after=True` and
  `recurrent_activation='sigmoid'`.
1976
1977  Args:
1978    units: Positive integer, dimensionality of the output space.
1979    activation: Activation function to use.
1980      Default: hyperbolic tangent (`tanh`).
1981      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: hard sigmoid (`hard_sigmoid`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
1988    use_bias: Boolean, whether the layer uses a bias vector.
1989    kernel_initializer: Initializer for the `kernel` weights matrix,
1990      used for the linear transformation of the inputs.
1991    recurrent_initializer: Initializer for the `recurrent_kernel`
1992      weights matrix, used for the linear transformation of the recurrent state.
1993    bias_initializer: Initializer for the bias vector.
1994    kernel_regularizer: Regularizer function applied to
1995      the `kernel` weights matrix.
1996    recurrent_regularizer: Regularizer function applied to
1997      the `recurrent_kernel` weights matrix.
1998    bias_regularizer: Regularizer function applied to the bias vector.
1999    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
2001    kernel_constraint: Constraint function applied to
2002      the `kernel` weights matrix.
2003    recurrent_constraint: Constraint function applied to
2004      the `recurrent_kernel` weights matrix.
2005    bias_constraint: Constraint function applied to the bias vector.
2006    dropout: Float between 0 and 1.
2007      Fraction of the units to drop for
2008      the linear transformation of the inputs.
2009    recurrent_dropout: Float between 0 and 1.
2010      Fraction of the units to drop for
2011      the linear transformation of the recurrent state.
2012    return_sequences: Boolean. Whether to return the last output
2013      in the output sequence, or the full sequence.
2014    return_state: Boolean. Whether to return the last state
2015      in addition to the output.
2016    go_backwards: Boolean (default False).
2017      If True, process the input sequence backwards and return the
2018      reversed sequence.
2019    stateful: Boolean (default False). If True, the last state
2020      for each sample at index i in a batch will be used as initial
2021      state for the sample of index i in the following batch.
2022    unroll: Boolean (default False).
2023      If True, the network will be unrolled,
2024      else a symbolic loop will be used.
      Unrolling can speed up an RNN,
2026      although it tends to be more memory-intensive.
2027      Unrolling is only suitable for short sequences.
2028    time_major: The shape format of the `inputs` and `outputs` tensors.
2029      If True, the inputs and outputs will be in shape
2030      `(timesteps, batch, ...)`, whereas in the False case, it will be
2031      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2032      efficient because it avoids transposes at the beginning and end of the
2033      RNN calculation. However, most TensorFlow data is batch-major, so by
2034      default this function accepts input and emits output in batch-major
2035      form.
2036    reset_after: GRU convention (whether to apply reset gate after or
2037      before matrix multiplication). False = "before" (default),
2038      True = "after" (CuDNN compatible).
2039
2040  Call arguments:
2041    inputs: A 3D tensor.
2042    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2043      a given timestep should be masked. An individual `True` entry indicates
2044      that the corresponding timestep should be utilized, while a `False`
2045      entry indicates that the corresponding timestep should be ignored.
2046    training: Python boolean indicating whether the layer should behave in
2047      training mode or in inference mode. This argument is passed to the cell
2048      when calling it. This is only relevant if `dropout` or
2049      `recurrent_dropout` is used.
2050    initial_state: List of initial state tensors to be passed to the first
2051      call of the cell.
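
  Example (an illustrative usage sketch; the input shape is arbitrary, and
  `reset_after=True` with `recurrent_activation='sigmoid'` selects the
  CuDNN-compatible variant described above):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  gru = GRU(4, reset_after=True, recurrent_activation='sigmoid')
  output = gru(inputs)  # The output has shape `[32, 4]`.

  gru = GRU(4, return_sequences=True, return_state=True)
  # whole_sequence_output has shape `[32, 10, 4]`.
  # final_state has shape `[32, 4]`.
  whole_sequence_output, final_state = gru(inputs)
  ```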
2052  """
2053
2054  def __init__(self,
2055               units,
2056               activation='tanh',
2057               recurrent_activation='hard_sigmoid',
2058               use_bias=True,
2059               kernel_initializer='glorot_uniform',
2060               recurrent_initializer='orthogonal',
2061               bias_initializer='zeros',
2062               kernel_regularizer=None,
2063               recurrent_regularizer=None,
2064               bias_regularizer=None,
2065               activity_regularizer=None,
2066               kernel_constraint=None,
2067               recurrent_constraint=None,
2068               bias_constraint=None,
2069               dropout=0.,
2070               recurrent_dropout=0.,
2071               return_sequences=False,
2072               return_state=False,
2073               go_backwards=False,
2074               stateful=False,
2075               unroll=False,
2076               reset_after=False,
2077               **kwargs):
2078    implementation = kwargs.pop('implementation', 1)
2079    if implementation == 0:
      logging.warning('`implementation=0` has been deprecated, '
                      'and now defaults to `implementation=1`. '
                      'Please update your layer call.')
2083    if 'enable_caching_device' in kwargs:
2084      cell_kwargs = {'enable_caching_device':
2085                     kwargs.pop('enable_caching_device')}
2086    else:
2087      cell_kwargs = {}
2088    cell = GRUCell(
2089        units,
2090        activation=activation,
2091        recurrent_activation=recurrent_activation,
2092        use_bias=use_bias,
2093        kernel_initializer=kernel_initializer,
2094        recurrent_initializer=recurrent_initializer,
2095        bias_initializer=bias_initializer,
2096        kernel_regularizer=kernel_regularizer,
2097        recurrent_regularizer=recurrent_regularizer,
2098        bias_regularizer=bias_regularizer,
2099        kernel_constraint=kernel_constraint,
2100        recurrent_constraint=recurrent_constraint,
2101        bias_constraint=bias_constraint,
2102        dropout=dropout,
2103        recurrent_dropout=recurrent_dropout,
2104        implementation=implementation,
2105        reset_after=reset_after,
2106        dtype=kwargs.get('dtype'),
2107        trainable=kwargs.get('trainable', True),
2108        **cell_kwargs)
2109    super(GRU, self).__init__(
2110        cell,
2111        return_sequences=return_sequences,
2112        return_state=return_state,
2113        go_backwards=go_backwards,
2114        stateful=stateful,
2115        unroll=unroll,
2116        **kwargs)
2117    self.activity_regularizer = regularizers.get(activity_regularizer)
2118    self.input_spec = [InputSpec(ndim=3)]
2119
2120  def call(self, inputs, mask=None, training=None, initial_state=None):
2121    return super(GRU, self).call(
2122        inputs, mask=mask, training=training, initial_state=initial_state)
2123
2124  @property
2125  def units(self):
2126    return self.cell.units
2127
2128  @property
2129  def activation(self):
2130    return self.cell.activation
2131
2132  @property
2133  def recurrent_activation(self):
2134    return self.cell.recurrent_activation
2135
2136  @property
2137  def use_bias(self):
2138    return self.cell.use_bias
2139
2140  @property
2141  def kernel_initializer(self):
2142    return self.cell.kernel_initializer
2143
2144  @property
2145  def recurrent_initializer(self):
2146    return self.cell.recurrent_initializer
2147
2148  @property
2149  def bias_initializer(self):
2150    return self.cell.bias_initializer
2151
2152  @property
2153  def kernel_regularizer(self):
2154    return self.cell.kernel_regularizer
2155
2156  @property
2157  def recurrent_regularizer(self):
2158    return self.cell.recurrent_regularizer
2159
2160  @property
2161  def bias_regularizer(self):
2162    return self.cell.bias_regularizer
2163
2164  @property
2165  def kernel_constraint(self):
2166    return self.cell.kernel_constraint
2167
2168  @property
2169  def recurrent_constraint(self):
2170    return self.cell.recurrent_constraint
2171
2172  @property
2173  def bias_constraint(self):
2174    return self.cell.bias_constraint
2175
2176  @property
2177  def dropout(self):
2178    return self.cell.dropout
2179
2180  @property
2181  def recurrent_dropout(self):
2182    return self.cell.recurrent_dropout
2183
2184  @property
2185  def implementation(self):
2186    return self.cell.implementation
2187
2188  @property
2189  def reset_after(self):
2190    return self.cell.reset_after
2191
2192  def get_config(self):
2193    config = {
2194        'units':
2195            self.units,
2196        'activation':
2197            activations.serialize(self.activation),
2198        'recurrent_activation':
2199            activations.serialize(self.recurrent_activation),
2200        'use_bias':
2201            self.use_bias,
2202        'kernel_initializer':
2203            initializers.serialize(self.kernel_initializer),
2204        'recurrent_initializer':
2205            initializers.serialize(self.recurrent_initializer),
2206        'bias_initializer':
2207            initializers.serialize(self.bias_initializer),
2208        'kernel_regularizer':
2209            regularizers.serialize(self.kernel_regularizer),
2210        'recurrent_regularizer':
2211            regularizers.serialize(self.recurrent_regularizer),
2212        'bias_regularizer':
2213            regularizers.serialize(self.bias_regularizer),
2214        'activity_regularizer':
2215            regularizers.serialize(self.activity_regularizer),
2216        'kernel_constraint':
2217            constraints.serialize(self.kernel_constraint),
2218        'recurrent_constraint':
2219            constraints.serialize(self.recurrent_constraint),
2220        'bias_constraint':
2221            constraints.serialize(self.bias_constraint),
2222        'dropout':
2223            self.dropout,
2224        'recurrent_dropout':
2225            self.recurrent_dropout,
2226        'implementation':
2227            self.implementation,
2228        'reset_after':
2229            self.reset_after
2230    }
2231    config.update(_config_for_enable_caching_device(self.cell))
2232    base_config = super(GRU, self).get_config()
2233    del base_config['cell']
2234    return dict(list(base_config.items()) + list(config.items()))
2235
2236  @classmethod
2237  def from_config(cls, config):
2238    if 'implementation' in config and config['implementation'] == 0:
2239      config['implementation'] = 1
2240    return cls(**config)
2241
2242
2243@keras_export(v1=['keras.layers.LSTMCell'])
2244class LSTMCell(DropoutRNNCellMixin, Layer):
2245  """Cell class for the LSTM layer.
2246
2247  Args:
2248    units: Positive integer, dimensionality of the output space.
2249    activation: Activation function to use.
2250      Default: hyperbolic tangent (`tanh`).
2251      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: hard sigmoid (`hard_sigmoid`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
2258    use_bias: Boolean, whether the layer uses a bias vector.
2259    kernel_initializer: Initializer for the `kernel` weights matrix,
2260      used for the linear transformation of the inputs.
2261    recurrent_initializer: Initializer for the `recurrent_kernel`
2262      weights matrix,
2263      used for the linear transformation of the recurrent state.
2264    bias_initializer: Initializer for the bias vector.
2265    unit_forget_bias: Boolean.
2266      If True, add 1 to the bias of the forget gate at initialization.
      Setting it to `True` will also force `bias_initializer="zeros"`.
2268      This is recommended in [Jozefowicz et al., 2015](
2269        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
2270    kernel_regularizer: Regularizer function applied to
2271      the `kernel` weights matrix.
2272    recurrent_regularizer: Regularizer function applied to
2273      the `recurrent_kernel` weights matrix.
2274    bias_regularizer: Regularizer function applied to the bias vector.
2275    kernel_constraint: Constraint function applied to
2276      the `kernel` weights matrix.
2277    recurrent_constraint: Constraint function applied to
2278      the `recurrent_kernel` weights matrix.
2279    bias_constraint: Constraint function applied to the bias vector.
2280    dropout: Float between 0 and 1.
2281      Fraction of the units to drop for
2282      the linear transformation of the inputs.
2283    recurrent_dropout: Float between 0 and 1.
2284      Fraction of the units to drop for
2285      the linear transformation of the recurrent state.
2286
2287  Call arguments:
2288    inputs: A 2D tensor.
2289    states: List of state tensors corresponding to the previous timestep.
2290    training: Python boolean indicating whether the layer should behave in
2291      training mode or in inference mode. Only relevant when `dropout` or
2292      `recurrent_dropout` is used.
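
  Example (an illustrative usage sketch, following the `SimpleRNNCell` example
  earlier in this module; the input shape is arbitrary):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  rnn = tf.keras.layers.RNN(LSTMCell(4))
  output = rnn(inputs)  # The output has shape `[32, 4]`.

  rnn = tf.keras.layers.RNN(
      LSTMCell(4), return_sequences=True, return_state=True)
  # whole_sequence_output has shape `[32, 10, 4]`.
  # final_memory_state and final_carry_state each have shape `[32, 4]`.
  whole_sequence_output, final_memory_state, final_carry_state = rnn(inputs)
  ```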
2293  """
2294
2295  def __init__(self,
2296               units,
2297               activation='tanh',
2298               recurrent_activation='hard_sigmoid',
2299               use_bias=True,
2300               kernel_initializer='glorot_uniform',
2301               recurrent_initializer='orthogonal',
2302               bias_initializer='zeros',
2303               unit_forget_bias=True,
2304               kernel_regularizer=None,
2305               recurrent_regularizer=None,
2306               bias_regularizer=None,
2307               kernel_constraint=None,
2308               recurrent_constraint=None,
2309               bias_constraint=None,
2310               dropout=0.,
2311               recurrent_dropout=0.,
2312               **kwargs):
2313    # By default use cached variable under v2 mode, see b/143699808.
2314    if ops.executing_eagerly_outside_functions():
2315      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
2316    else:
2317      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
2318    super(LSTMCell, self).__init__(**kwargs)
2319    self.units = units
2320    self.activation = activations.get(activation)
2321    self.recurrent_activation = activations.get(recurrent_activation)
2322    self.use_bias = use_bias
2323
2324    self.kernel_initializer = initializers.get(kernel_initializer)
2325    self.recurrent_initializer = initializers.get(recurrent_initializer)
2326    self.bias_initializer = initializers.get(bias_initializer)
2327    self.unit_forget_bias = unit_forget_bias
2328
2329    self.kernel_regularizer = regularizers.get(kernel_regularizer)
2330    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
2331    self.bias_regularizer = regularizers.get(bias_regularizer)
2332
2333    self.kernel_constraint = constraints.get(kernel_constraint)
2334    self.recurrent_constraint = constraints.get(recurrent_constraint)
2335    self.bias_constraint = constraints.get(bias_constraint)
2336
2337    self.dropout = min(1., max(0., dropout))
2338    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
2339    implementation = kwargs.pop('implementation', 1)
2340    if self.recurrent_dropout != 0 and implementation != 1:
2341      logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
2342      self.implementation = 1
2343    else:
2344      self.implementation = implementation
2345    # tuple(_ListWrapper) was silently dropping list content in at least 2.7.10,
2346    # and fixed after 2.7.16. Converting the state_size to wrapper around
2347    # NoDependency(), so that the base_layer.__setattr__ will not convert it to
2348    # ListWrapper. Down the stream, self.states will be a list since it is
2349    # generated from nest.map_structure with list, and tuple(list) will work
2350    # properly.
2351    self.state_size = data_structures.NoDependency([self.units, self.units])
2352    self.output_size = self.units
2353
2354  @tf_utils.shape_type_conversion
2355  def build(self, input_shape):
2356    default_caching_device = _caching_device(self)
2357    input_dim = input_shape[-1]
2358    self.kernel = self.add_weight(
2359        shape=(input_dim, self.units * 4),
2360        name='kernel',
2361        initializer=self.kernel_initializer,
2362        regularizer=self.kernel_regularizer,
2363        constraint=self.kernel_constraint,
2364        caching_device=default_caching_device)
2365    self.recurrent_kernel = self.add_weight(
2366        shape=(self.units, self.units * 4),
2367        name='recurrent_kernel',
2368        initializer=self.recurrent_initializer,
2369        regularizer=self.recurrent_regularizer,
2370        constraint=self.recurrent_constraint,
2371        caching_device=default_caching_device)
2372
2373    if self.use_bias:
2374      if self.unit_forget_bias:
2375
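        # The bias is laid out as [input, forget, cell, output] gate blocks;
        # initialize the forget-gate block to ones and the remaining blocks
        # with `bias_initializer` (see Jozefowicz et al., 2015).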
2376        def bias_initializer(_, *args, **kwargs):
2377          return K.concatenate([
2378              self.bias_initializer((self.units,), *args, **kwargs),
2379              initializers.get('ones')((self.units,), *args, **kwargs),
2380              self.bias_initializer((self.units * 2,), *args, **kwargs),
2381          ])
2382      else:
2383        bias_initializer = self.bias_initializer
2384      self.bias = self.add_weight(
2385          shape=(self.units * 4,),
2386          name='bias',
2387          initializer=bias_initializer,
2388          regularizer=self.bias_regularizer,
2389          constraint=self.bias_constraint,
2390          caching_device=default_caching_device)
2391    else:
2392      self.bias = None
2393    self.built = True
2394
2395  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2396    """Computes carry and output using split kernels."""
2397    x_i, x_f, x_c, x_o = x
2398    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2399    i = self.recurrent_activation(
2400        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
2401    f = self.recurrent_activation(x_f + K.dot(
2402        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
2403    c = f * c_tm1 + i * self.activation(x_c + K.dot(
2404        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2405    o = self.recurrent_activation(
2406        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
2407    return c, o
2408
2409  def _compute_carry_and_output_fused(self, z, c_tm1):
2410    """Computes carry and output using fused kernels."""
2411    z0, z1, z2, z3 = z
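    # z0..z3 are the pre-activations of the input gate, forget gate, candidate
    # cell state, and output gate, in that order.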
2412    i = self.recurrent_activation(z0)
2413    f = self.recurrent_activation(z1)
2414    c = f * c_tm1 + i * self.activation(z2)
2415    o = self.recurrent_activation(z3)
2416    return c, o
2417
2418  def call(self, inputs, states, training=None):
2419    h_tm1 = states[0]  # previous memory state
2420    c_tm1 = states[1]  # previous carry state
2421
2422    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
2423    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
2424        h_tm1, training, count=4)
2425
2426    if self.implementation == 1:
2427      if 0 < self.dropout < 1.:
2428        inputs_i = inputs * dp_mask[0]
2429        inputs_f = inputs * dp_mask[1]
2430        inputs_c = inputs * dp_mask[2]
2431        inputs_o = inputs * dp_mask[3]
2432      else:
2433        inputs_i = inputs
2434        inputs_f = inputs
2435        inputs_c = inputs
2436        inputs_o = inputs
2437      k_i, k_f, k_c, k_o = array_ops.split(
2438          self.kernel, num_or_size_splits=4, axis=1)
2439      x_i = K.dot(inputs_i, k_i)
2440      x_f = K.dot(inputs_f, k_f)
2441      x_c = K.dot(inputs_c, k_c)
2442      x_o = K.dot(inputs_o, k_o)
2443      if self.use_bias:
2444        b_i, b_f, b_c, b_o = array_ops.split(
2445            self.bias, num_or_size_splits=4, axis=0)
2446        x_i = K.bias_add(x_i, b_i)
2447        x_f = K.bias_add(x_f, b_f)
2448        x_c = K.bias_add(x_c, b_c)
2449        x_o = K.bias_add(x_o, b_o)
2450
2451      if 0 < self.recurrent_dropout < 1.:
2452        h_tm1_i = h_tm1 * rec_dp_mask[0]
2453        h_tm1_f = h_tm1 * rec_dp_mask[1]
2454        h_tm1_c = h_tm1 * rec_dp_mask[2]
2455        h_tm1_o = h_tm1 * rec_dp_mask[3]
2456      else:
2457        h_tm1_i = h_tm1
2458        h_tm1_f = h_tm1
2459        h_tm1_c = h_tm1
2460        h_tm1_o = h_tm1
2461      x = (x_i, x_f, x_c, x_o)
2462      h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
2463      c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
2464    else:
2465      if 0. < self.dropout < 1.:
2466        inputs = inputs * dp_mask[0]
2467      z = K.dot(inputs, self.kernel)
2468      z += K.dot(h_tm1, self.recurrent_kernel)
2469      if self.use_bias:
2470        z = K.bias_add(z, self.bias)
2471
2472      z = array_ops.split(z, num_or_size_splits=4, axis=1)
2473      c, o = self._compute_carry_and_output_fused(z, c_tm1)
2474
2475    h = o * self.activation(c)
2476    return h, [h, c]
2477
2478  def get_config(self):
2479    config = {
2480        'units':
2481            self.units,
2482        'activation':
2483            activations.serialize(self.activation),
2484        'recurrent_activation':
2485            activations.serialize(self.recurrent_activation),
2486        'use_bias':
2487            self.use_bias,
2488        'kernel_initializer':
2489            initializers.serialize(self.kernel_initializer),
2490        'recurrent_initializer':
2491            initializers.serialize(self.recurrent_initializer),
2492        'bias_initializer':
2493            initializers.serialize(self.bias_initializer),
2494        'unit_forget_bias':
2495            self.unit_forget_bias,
2496        'kernel_regularizer':
2497            regularizers.serialize(self.kernel_regularizer),
2498        'recurrent_regularizer':
2499            regularizers.serialize(self.recurrent_regularizer),
2500        'bias_regularizer':
2501            regularizers.serialize(self.bias_regularizer),
2502        'kernel_constraint':
2503            constraints.serialize(self.kernel_constraint),
2504        'recurrent_constraint':
2505            constraints.serialize(self.recurrent_constraint),
2506        'bias_constraint':
2507            constraints.serialize(self.bias_constraint),
2508        'dropout':
2509            self.dropout,
2510        'recurrent_dropout':
2511            self.recurrent_dropout,
2512        'implementation':
2513            self.implementation
2514    }
2515    config.update(_config_for_enable_caching_device(self))
2516    base_config = super(LSTMCell, self).get_config()
2517    return dict(list(base_config.items()) + list(config.items()))
2518
2519  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
2520    return list(_generate_zero_filled_state_for_cell(
2521        self, inputs, batch_size, dtype))
2522
2523
2524@keras_export('keras.experimental.PeepholeLSTMCell')
2525class PeepholeLSTMCell(LSTMCell):
2526  """Equivalent to LSTMCell class but adds peephole connections.
2527
  Peephole connections allow the gates to utilize the previous internal (cell)
  state as well as the previous hidden state (which is all that LSTMCell has
  access to). This allows PeepholeLSTMCell to learn precise timings more
  easily than LSTMCell.
2531
2532  From [Gers et al., 2002](
2533    http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
2534
2535  "We find that LSTM augmented by 'peephole connections' from its internal
2536  cells to its multiplicative gates can learn the fine distinction between
2537  sequences of spikes spaced either 50 or 49 time steps apart without the help
2538  of any short training exemplars."
2539
2540  The peephole implementation is based on:
2541
2542  [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)
2543
2544  Example:
2545
  ```python
  # Create 2 PeepholeLSTMCells
  peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
  # Create a layer composed sequentially of the peephole LSTM cells.
  layer = RNN(peephole_lstm_cells)
  inputs = tf.keras.Input((timesteps, input_dim))
  outputs = layer(inputs)
  ```
2554  """
2555
2556  def __init__(self,
2557               units,
2558               activation='tanh',
2559               recurrent_activation='hard_sigmoid',
2560               use_bias=True,
2561               kernel_initializer='glorot_uniform',
2562               recurrent_initializer='orthogonal',
2563               bias_initializer='zeros',
2564               unit_forget_bias=True,
2565               kernel_regularizer=None,
2566               recurrent_regularizer=None,
2567               bias_regularizer=None,
2568               kernel_constraint=None,
2569               recurrent_constraint=None,
2570               bias_constraint=None,
2571               dropout=0.,
2572               recurrent_dropout=0.,
2573               **kwargs):
2574    warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated '
2575                  'and will be removed in a future version. '
2576                  'Please use tensorflow_addons.rnn.PeepholeLSTMCell '
2577                  'instead.')
2578    super(PeepholeLSTMCell, self).__init__(
2579        units=units,
2580        activation=activation,
2581        recurrent_activation=recurrent_activation,
2582        use_bias=use_bias,
2583        kernel_initializer=kernel_initializer,
2584        recurrent_initializer=recurrent_initializer,
2585        bias_initializer=bias_initializer,
2586        unit_forget_bias=unit_forget_bias,
2587        kernel_regularizer=kernel_regularizer,
2588        recurrent_regularizer=recurrent_regularizer,
2589        bias_regularizer=bias_regularizer,
2590        kernel_constraint=kernel_constraint,
2591        recurrent_constraint=recurrent_constraint,
2592        bias_constraint=bias_constraint,
2593        dropout=dropout,
2594        recurrent_dropout=recurrent_dropout,
2595        implementation=kwargs.pop('implementation', 1),
2596        **kwargs)
2597
2598  def build(self, input_shape):
2599    super(PeepholeLSTMCell, self).build(input_shape)
    # The following are the weight vectors (each of shape `(units,)`) for the
    # peephole connections. They are multiplied element-wise with the previous
    # internal (cell) state when computing the carry and output.
2603    self.input_gate_peephole_weights = self.add_weight(
2604        shape=(self.units,),
2605        name='input_gate_peephole_weights',
2606        initializer=self.kernel_initializer)
2607    self.forget_gate_peephole_weights = self.add_weight(
2608        shape=(self.units,),
2609        name='forget_gate_peephole_weights',
2610        initializer=self.kernel_initializer)
2611    self.output_gate_peephole_weights = self.add_weight(
2612        shape=(self.units,),
2613        name='output_gate_peephole_weights',
2614        initializer=self.kernel_initializer)
2615
2616  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2617    x_i, x_f, x_c, x_o = x
2618    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
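    # `x_*` are the per-gate input projections and `h_tm1_*` the per-gate
    # (recurrent-dropout-masked) copies of the previous hidden state, as
    # prepared in `LSTMCell.call`. The peephole terms below add an
    # element-wise contribution from the previous cell state `c_tm1` (and from
    # the new cell state `c` for the output gate).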
2619    i = self.recurrent_activation(
2620        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
2621        self.input_gate_peephole_weights * c_tm1)
2622    f = self.recurrent_activation(x_f + K.dot(
2623        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
2624                                  self.forget_gate_peephole_weights * c_tm1)
2625    c = f * c_tm1 + i * self.activation(x_c + K.dot(
2626        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2627    o = self.recurrent_activation(
2628        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
2629        self.output_gate_peephole_weights * c)
2630    return c, o
2631
2632  def _compute_carry_and_output_fused(self, z, c_tm1):
2633    z0, z1, z2, z3 = z
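    # `z0`..`z3` are the fused per-gate pre-activations (input plus recurrent
    # contributions) for the input, forget, cell and output gates, as split in
    # `LSTMCell.call` when `implementation=2`; only the peephole terms are
    # added here.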
2634    i = self.recurrent_activation(z0 +
2635                                  self.input_gate_peephole_weights * c_tm1)
2636    f = self.recurrent_activation(z1 +
2637                                  self.forget_gate_peephole_weights * c_tm1)
2638    c = f * c_tm1 + i * self.activation(z2)
2639    o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
2640    return c, o
2641
2642
2643@keras_export(v1=['keras.layers.LSTM'])
2644class LSTM(RNN):
2645  """Long Short-Term Memory layer - Hochreiter 1997.
2646
  Note that this layer is not optimized for performance on GPU. Please use
2648  `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.
2649
2650  Args:
2651    units: Positive integer, dimensionality of the output space.
2652    activation: Activation function to use.
2653      Default: hyperbolic tangent (`tanh`).
2654      If you pass `None`, no activation is applied
2655      (ie. "linear" activation: `a(x) = x`).
2656    recurrent_activation: Activation function to use
2657      for the recurrent step.
2658      Default: hard sigmoid (`hard_sigmoid`).
2659      If you pass `None`, no activation is applied
2660      (ie. "linear" activation: `a(x) = x`).
2661    use_bias: Boolean, whether the layer uses a bias vector.
2662    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
2664    recurrent_initializer: Initializer for the `recurrent_kernel`
2665      weights matrix,
2666      used for the linear transformation of the recurrent state.
2667    bias_initializer: Initializer for the bias vector.
2668    unit_forget_bias: Boolean.
2669      If True, add 1 to the bias of the forget gate at initialization.
2670      Setting it to true will also force `bias_initializer="zeros"`.
2671      This is recommended in [Jozefowicz et al., 2015](
2672        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
2673    kernel_regularizer: Regularizer function applied to
2674      the `kernel` weights matrix.
2675    recurrent_regularizer: Regularizer function applied to
2676      the `recurrent_kernel` weights matrix.
2677    bias_regularizer: Regularizer function applied to the bias vector.
2678    activity_regularizer: Regularizer function applied to
2679      the output of the layer (its "activation").
2680    kernel_constraint: Constraint function applied to
2681      the `kernel` weights matrix.
2682    recurrent_constraint: Constraint function applied to
2683      the `recurrent_kernel` weights matrix.
2684    bias_constraint: Constraint function applied to the bias vector.
2685    dropout: Float between 0 and 1.
2686      Fraction of the units to drop for
2687      the linear transformation of the inputs.
2688    recurrent_dropout: Float between 0 and 1.
2689      Fraction of the units to drop for
2690      the linear transformation of the recurrent state.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
2693    return_state: Boolean. Whether to return the last state
2694      in addition to the output.
2695    go_backwards: Boolean (default False).
2696      If True, process the input sequence backwards and return the
2697      reversed sequence.
2698    stateful: Boolean (default False). If True, the last state
2699      for each sample at index i in a batch will be used as initial
2700      state for the sample of index i in the following batch.
2701    unroll: Boolean (default False).
2702      If True, the network will be unrolled,
2703      else a symbolic loop will be used.
      Unrolling can speed up an RNN,
2705      although it tends to be more memory-intensive.
2706      Unrolling is only suitable for short sequences.
2707    time_major: The shape format of the `inputs` and `outputs` tensors.
2708      If True, the inputs and outputs will be in shape
2709      `(timesteps, batch, ...)`, whereas in the False case, it will be
2710      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2711      efficient because it avoids transposes at the beginning and end of the
2712      RNN calculation. However, most TensorFlow data is batch-major, so by
2713      default this function accepts input and emits output in batch-major
2714      form.
2715
2716  Call arguments:
2717    inputs: A 3D tensor.
2718    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2719      a given timestep should be masked. An individual `True` entry indicates
2720      that the corresponding timestep should be utilized, while a `False`
2721      entry indicates that the corresponding timestep should be ignored.
2722    training: Python boolean indicating whether the layer should behave in
2723      training mode or in inference mode. This argument is passed to the cell
2724      when calling it. This is only relevant if `dropout` or
2725      `recurrent_dropout` is used.
2726    initial_state: List of initial state tensors to be passed to the first
2727      call of the cell.
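
  Example (a minimal usage sketch; shapes below are illustrative):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  lstm = tf.keras.layers.LSTM(4)
  output = lstm(inputs)  # `output` has shape `(32, 4)`.

  lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
  whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
  # `whole_seq_output` has shape `(32, 10, 4)`; both states have shape
  # `(32, 4)`.
  ```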
2728  """
2729
2730  def __init__(self,
2731               units,
2732               activation='tanh',
2733               recurrent_activation='hard_sigmoid',
2734               use_bias=True,
2735               kernel_initializer='glorot_uniform',
2736               recurrent_initializer='orthogonal',
2737               bias_initializer='zeros',
2738               unit_forget_bias=True,
2739               kernel_regularizer=None,
2740               recurrent_regularizer=None,
2741               bias_regularizer=None,
2742               activity_regularizer=None,
2743               kernel_constraint=None,
2744               recurrent_constraint=None,
2745               bias_constraint=None,
2746               dropout=0.,
2747               recurrent_dropout=0.,
2748               return_sequences=False,
2749               return_state=False,
2750               go_backwards=False,
2751               stateful=False,
2752               unroll=False,
2753               **kwargs):
2754    implementation = kwargs.pop('implementation', 1)
2755    if implementation == 0:
2756      logging.warning('`implementation=0` has been deprecated, '
                      'and now defaults to `implementation=1`. '
2758                      'Please update your layer call.')
2759    if 'enable_caching_device' in kwargs:
2760      cell_kwargs = {'enable_caching_device':
2761                     kwargs.pop('enable_caching_device')}
2762    else:
2763      cell_kwargs = {}
2764    cell = LSTMCell(
2765        units,
2766        activation=activation,
2767        recurrent_activation=recurrent_activation,
2768        use_bias=use_bias,
2769        kernel_initializer=kernel_initializer,
2770        recurrent_initializer=recurrent_initializer,
2771        unit_forget_bias=unit_forget_bias,
2772        bias_initializer=bias_initializer,
2773        kernel_regularizer=kernel_regularizer,
2774        recurrent_regularizer=recurrent_regularizer,
2775        bias_regularizer=bias_regularizer,
2776        kernel_constraint=kernel_constraint,
2777        recurrent_constraint=recurrent_constraint,
2778        bias_constraint=bias_constraint,
2779        dropout=dropout,
2780        recurrent_dropout=recurrent_dropout,
2781        implementation=implementation,
2782        dtype=kwargs.get('dtype'),
2783        trainable=kwargs.get('trainable', True),
2784        **cell_kwargs)
2785    super(LSTM, self).__init__(
2786        cell,
2787        return_sequences=return_sequences,
2788        return_state=return_state,
2789        go_backwards=go_backwards,
2790        stateful=stateful,
2791        unroll=unroll,
2792        **kwargs)
2793    self.activity_regularizer = regularizers.get(activity_regularizer)
2794    self.input_spec = [InputSpec(ndim=3)]
2795
2796  def call(self, inputs, mask=None, training=None, initial_state=None):
2797    return super(LSTM, self).call(
2798        inputs, mask=mask, training=training, initial_state=initial_state)
2799
2800  @property
2801  def units(self):
2802    return self.cell.units
2803
2804  @property
2805  def activation(self):
2806    return self.cell.activation
2807
2808  @property
2809  def recurrent_activation(self):
2810    return self.cell.recurrent_activation
2811
2812  @property
2813  def use_bias(self):
2814    return self.cell.use_bias
2815
2816  @property
2817  def kernel_initializer(self):
2818    return self.cell.kernel_initializer
2819
2820  @property
2821  def recurrent_initializer(self):
2822    return self.cell.recurrent_initializer
2823
2824  @property
2825  def bias_initializer(self):
2826    return self.cell.bias_initializer
2827
2828  @property
2829  def unit_forget_bias(self):
2830    return self.cell.unit_forget_bias
2831
2832  @property
2833  def kernel_regularizer(self):
2834    return self.cell.kernel_regularizer
2835
2836  @property
2837  def recurrent_regularizer(self):
2838    return self.cell.recurrent_regularizer
2839
2840  @property
2841  def bias_regularizer(self):
2842    return self.cell.bias_regularizer
2843
2844  @property
2845  def kernel_constraint(self):
2846    return self.cell.kernel_constraint
2847
2848  @property
2849  def recurrent_constraint(self):
2850    return self.cell.recurrent_constraint
2851
2852  @property
2853  def bias_constraint(self):
2854    return self.cell.bias_constraint
2855
2856  @property
2857  def dropout(self):
2858    return self.cell.dropout
2859
2860  @property
2861  def recurrent_dropout(self):
2862    return self.cell.recurrent_dropout
2863
2864  @property
2865  def implementation(self):
2866    return self.cell.implementation
2867
2868  def get_config(self):
2869    config = {
2870        'units':
2871            self.units,
2872        'activation':
2873            activations.serialize(self.activation),
2874        'recurrent_activation':
2875            activations.serialize(self.recurrent_activation),
2876        'use_bias':
2877            self.use_bias,
2878        'kernel_initializer':
2879            initializers.serialize(self.kernel_initializer),
2880        'recurrent_initializer':
2881            initializers.serialize(self.recurrent_initializer),
2882        'bias_initializer':
2883            initializers.serialize(self.bias_initializer),
2884        'unit_forget_bias':
2885            self.unit_forget_bias,
2886        'kernel_regularizer':
2887            regularizers.serialize(self.kernel_regularizer),
2888        'recurrent_regularizer':
2889            regularizers.serialize(self.recurrent_regularizer),
2890        'bias_regularizer':
2891            regularizers.serialize(self.bias_regularizer),
2892        'activity_regularizer':
2893            regularizers.serialize(self.activity_regularizer),
2894        'kernel_constraint':
2895            constraints.serialize(self.kernel_constraint),
2896        'recurrent_constraint':
2897            constraints.serialize(self.recurrent_constraint),
2898        'bias_constraint':
2899            constraints.serialize(self.bias_constraint),
2900        'dropout':
2901            self.dropout,
2902        'recurrent_dropout':
2903            self.recurrent_dropout,
2904        'implementation':
2905            self.implementation
2906    }
2907    config.update(_config_for_enable_caching_device(self.cell))
2908    base_config = super(LSTM, self).get_config()
2909    del base_config['cell']
2910    return dict(list(base_config.items()) + list(config.items()))
2911
2912  @classmethod
2913  def from_config(cls, config):
2914    if 'implementation' in config and config['implementation'] == 0:
2915      config['implementation'] = 1
2916    return cls(**config)
2917
2918
2919def _generate_dropout_mask(ones, rate, training=None, count=1):
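  """Generates dropout mask(s) by applying dropout to a tensor of ones.

  Returns the unmodified `ones` tensor when not in the training phase; when
  `count > 1`, a list of `count` independently sampled masks is returned.
  """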
2920  def dropped_inputs():
2921    return K.dropout(ones, rate)
2922
2923  if count > 1:
2924    return [
2925        K.in_train_phase(dropped_inputs, ones, training=training)
2926        for _ in range(count)
2927    ]
2928  return K.in_train_phase(dropped_inputs, ones, training=training)
2929
2930
2931def _standardize_args(inputs, initial_state, constants, num_constants):
2932  """Standardizes `__call__` to a single list of tensor inputs.
2933
2934  When running a model loaded from a file, the input tensors
2935  `initial_state` and `constants` can be passed to `RNN.__call__()` as part
2936  of `inputs` instead of by the dedicated keyword arguments. This method
2937  makes sure the arguments are separated and that `initial_state` and
2938  `constants` are lists of tensors (or None).
2939
2940  Args:
    inputs: Tensor or list/tuple of tensors, which may include constants
      and initial states. In that case `num_constants` must be specified.
2943    initial_state: Tensor or list of tensors or None, initial states.
2944    constants: Tensor or list of tensors or None, constant tensors.
2945    num_constants: Expected number of constants (if constants are passed as
      part of the `inputs` list).
2947
2948  Returns:
2949    inputs: Single tensor or tuple of tensors.
2950    initial_state: List of tensors or None.
2951    constants: List of tensors or None.
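
  For example, when everything is passed through `inputs` (as happens when
  running a loaded model), a call such as `layer([x, state_h, state_c])` with
  `num_constants=None` is standardized into `inputs=x` and
  `initial_state=[state_h, state_c]`.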
2952  """
2953  if isinstance(inputs, list):
    # There are several situations here:
    # In graph mode, __call__ will only be called once. The initial_state
    # and constants could be in inputs (from file loading).
    # In eager mode, __call__ will be called twice: once during
    # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
    # model.fit/train_on_batch/predict with real np data. In the second case,
    # the inputs will contain initial_state and constants as eager tensors.
    #
    # In either case, the real input is the first item in the list, which may
    # itself be a nested structure. It is followed by the initial states, which
    # may be a list of items (or a list of lists if the initial state is a
    # complex structure), and finally by the constants, which form a flat list.
2966    assert initial_state is None and constants is None
2967    if num_constants:
2968      constants = inputs[-num_constants:]
2969      inputs = inputs[:-num_constants]
2970    if len(inputs) > 1:
2971      initial_state = inputs[1:]
2972      inputs = inputs[:1]
2973
2974    if len(inputs) > 1:
2975      inputs = tuple(inputs)
2976    else:
2977      inputs = inputs[0]
2978
2979  def to_list_or_none(x):
2980    if x is None or isinstance(x, list):
2981      return x
2982    if isinstance(x, tuple):
2983      return list(x)
2984    return [x]
2985
2986  initial_state = to_list_or_none(initial_state)
2987  constants = to_list_or_none(constants)
2988
2989  return inputs, initial_state, constants
2990
2991
2992def _is_multiple_state(state_size):
2993  """Check whether the state_size contains multiple states."""
2994  return (hasattr(state_size, '__len__') and
2995          not isinstance(state_size, tensor_shape.TensorShape))
2996
2997
2998def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
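  # When `inputs` is provided, its leading (batch) dimension and dtype take
  # precedence over the explicitly passed `batch_size` and `dtype`.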
2999  if inputs is not None:
3000    batch_size = array_ops.shape(inputs)[0]
3001    dtype = inputs.dtype
3002  return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
3003
3004
3005def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
3006  """Generate a zero filled tensor with shape [batch_size, state_size]."""
3007  if batch_size_tensor is None or dtype is None:
3008    raise ValueError(
3009        'batch_size and dtype cannot be None while constructing initial state: '
3010        'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
3011
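  # Builds a zeros tensor for a single (unnested) state size; nested state
  # structures are handled below via `nest.map_structure`.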
3012  def create_zeros(unnested_state_size):
3013    flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list()
3014    init_state_size = [batch_size_tensor] + flat_dims
3015    return array_ops.zeros(init_state_size, dtype=dtype)
3016
3017  if nest.is_nested(state_size):
3018    return nest.map_structure(create_zeros, state_size)
3019  else:
3020    return create_zeros(state_size)
3021
3022
3023def _caching_device(rnn_cell):
3024  """Returns the caching device for the RNN variable.
3025
  This is useful for distributed training, when the variable is not located on
  the same device as the training worker. Enabling the device cache lets the
  worker read the variable once and cache it locally, rather than reading it
  from the remote device at every time step.

  Note that this assumes the variable the cell needs at each time step has the
  same value throughout the forward pass and only gets updated during backprop.
  This is true for all the default cells (SimpleRNN, GRU, LSTM). If the cell
  body relies on any variable that gets updated every time step, then the
  caching device will cause it to read a stale value.
3036
3037  Args:
    rnn_cell: The RNN cell instance.
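
  Returns:
    A caching device callable (mapping an op to its device) used when creating
    the cell's variables, or `None` when variable caching should not be used.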
3039  """
3040  if context.executing_eagerly():
3041    # caching_device is not supported in eager mode.
3042    return None
3043  if not getattr(rnn_cell, '_enable_caching_device', False):
3044    return None
3045  # Don't set a caching device when running in a loop, since it is possible that
3046  # train steps could be wrapped in a tf.while_loop. In that scenario caching
3047  # prevents forward computations in loop iterations from re-reading the
3048  # updated weights.
3049  if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
3050    logging.warn('Variable read device caching has been disabled because the '
                 'RNN is in a tf.while_loop context, which would cause stale '
                 'values to be read in the forward path. This could slow down '
3053                 'the training due to duplicated variable reads. Please '
3054                 'consider updating your code to remove tf.while_loop if '
3055                 'possible.')
3056    return None
3057  if (rnn_cell._dtype_policy.compute_dtype !=
3058      rnn_cell._dtype_policy.variable_dtype):
3059    logging.warn('Variable read device caching has been disabled since it '
3060                 'doesn\'t work with the mixed precision API. This is '
3061                 'likely to cause a slowdown for RNN training due to '
                 'duplicated variable reads at each timestep, which '
                 'will be significant in a multi remote worker setting. '
                 'Please consider disabling the mixed precision API if '
                 'performance has been affected.')
3066    return None
  # Cache the value on the device that accesses the variable.
3068  return lambda op: op.device
3069
3070
3071def _config_for_enable_caching_device(rnn_cell):
3072  """Return the dict config for RNN cell wrt to enable_caching_device field.
3073
3074  Since enable_caching_device is a internal implementation detail for speed up
3075  the RNN variable read when running on the multi remote worker setting, we
3076  don't want this config to be serialized constantly in the JSON. We will only
3077  serialize this field when a none default value is used to create the cell.
3078  Args:
3079    rnn_cell: the RNN cell for serialize.
3080
3081  Returns:
3082    A dict which contains the JSON config for enable_caching_device value or
3083    empty dict if the enable_caching_device value is same as the default value.
3084  """
3085  default_enable_caching_device = ops.executing_eagerly_outside_functions()
3086  if rnn_cell._enable_caching_device != default_enable_caching_device:
3087    return {'enable_caching_device': rnn_cell._enable_caching_device}
3088  return {}
3089