1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Recurrent layers for TF 2.0.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import uuid
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import function
25from tensorflow.python.eager.context import get_device_name
26from tensorflow.python.framework import config
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import device
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.keras import activations
32from tensorflow.python.keras import backend as K
33from tensorflow.python.keras.engine.input_spec import InputSpec
34from tensorflow.python.keras.layers import recurrent
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import gen_cudnn_rnn_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import nn
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variables
42from tensorflow.python.platform import sysconfig
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.util.tf_export import keras_export
45
46
47# The following string constants are used by Defun approach for unified backend
48# of LSTM and GRU.
49_FUNCTION_API_NAME_ATTRIBUTE = 'api_implements'
50_FUNCTION_DEVICE_ATTRIBUTE = 'api_preferred_device'
51_CPU_DEVICE_NAME = 'CPU'
52_GPU_DEVICE_NAME = 'GPU'
53
# The following number constants are used to represent the runtime of the defun
# backend function. Since the CPU/GPU implementations are mathematically
# identical, we need some signal for the function to indicate which one was
# executed. This is used in tests to verify the correctness of swapping the
# backend function.
_RUNTIME_UNKNOWN = 0
_RUNTIME_CPU = 1
_RUNTIME_GPU = 2
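# For example, a test can construct the layer with the testing-only
# `return_runtime` flag (popped from kwargs in the GRU/LSTM constructors below)
# and compare the returned runtime tensor against these constants to verify
# which backend actually executed.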

_CUDNN_AVAILABLE_MSG = 'Layer %s will use cuDNN kernels when running on GPU.'
_CUDNN_NOT_AVAILABLE_MSG = ('Layer %s will not use cuDNN kernels since it '
                            'doesn\'t meet the criteria. It will '
                            'use a generic GPU kernel as fallback when running '
                            'on GPU.')


def _use_new_code():
  return False
71
72
73# TODO(b/169707691): The wrapper can be removed if TFLite doesn't need to rely
74# on supportive attributes from LSTM/GRU.
75class _DefunWrapper(object):
76  """A wrapper with no deep copy of the Defun in LSTM/GRU layer."""
77
78  def __init__(self, time_major, go_backwards, layer_name):
79    self.time_major = time_major
80    self.go_backwards = go_backwards
81    self.layer_name = layer_name
    if self.layer_name not in ['lstm', 'gru']:
      raise ValueError('Defun wrapper only applies to LSTM and GRU layers, '
                       'but got {}'.format(self.layer_name))
85    # The first two attributes are added to support TFLite use case.
86    supportive_attributes = {
87        'time_major': self.time_major,
88        'go_backwards': self.go_backwards,
89        _FUNCTION_API_NAME_ATTRIBUTE: self.layer_name + '_' + str(uuid.uuid4())
90    }
91    if self.layer_name == 'lstm':
92      layer_func = lstm_with_backend_selection
93    else:
94      layer_func = gru_with_backend_selection
95
96    self.defun_layer = function.defun_with_attributes(
97        layer_func,
98        attributes=supportive_attributes,
99        autograph=False)
100
101  def __deepcopy__(self, memo):
102    new_wrapper = type(self)(
103        self.time_major, self.go_backwards, self.layer_name)
104    memo[id(self)] = new_wrapper
105    return new_wrapper
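# A minimal illustrative sketch (not part of the public API; this path is only
# exercised when `_use_new_code()` returns True):
#
#   wrapper = _DefunWrapper(time_major=False, go_backwards=False,
#                           layer_name='gru')
#   outputs = wrapper.defun_layer(**gru_kwargs)  # kwargs as built in GRU.call
#
# The `api_implements` attribute attached above lets the runtime pair this
# function with a device-specific implementation.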
106
107
108@keras_export('keras.layers.GRUCell', v1=[])
109class GRUCell(recurrent.GRUCell):
110  """Cell class for the GRU layer.
111
112  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
113  for details about the usage of RNN API.
114
  This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.GRU` processes the whole sequence.
117
118  For example:
119
120  >>> inputs = tf.random.normal([32, 10, 8])
121  >>> rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4))
122  >>> output = rnn(inputs)
123  >>> print(output.shape)
124  (32, 4)
125  >>> rnn = tf.keras.layers.RNN(
126  ...    tf.keras.layers.GRUCell(4),
127  ...    return_sequences=True,
128  ...    return_state=True)
129  >>> whole_sequence_output, final_state = rnn(inputs)
130  >>> print(whole_sequence_output.shape)
131  (32, 10, 4)
132  >>> print(final_state.shape)
133  (32, 4)
134
135  Args:
136    units: Positive integer, dimensionality of the output space.
137    activation: Activation function to use. Default: hyperbolic tangent
138      (`tanh`). If you pass None, no activation is applied
139      (ie. "linear" activation: `a(x) = x`).
140    recurrent_activation: Activation function to use for the recurrent step.
141      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
142      applied (ie. "linear" activation: `a(x) = x`).
143    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
144    kernel_initializer: Initializer for the `kernel` weights matrix,
145      used for the linear transformation of the inputs. Default:
146      `glorot_uniform`.
147    recurrent_initializer: Initializer for the `recurrent_kernel`
148      weights matrix, used for the linear transformation of the recurrent state.
149      Default: `orthogonal`.
150    bias_initializer: Initializer for the bias vector. Default: `zeros`.
151    kernel_regularizer: Regularizer function applied to the `kernel` weights
152      matrix. Default: `None`.
153    recurrent_regularizer: Regularizer function applied to the
154      `recurrent_kernel` weights matrix. Default: `None`.
155    bias_regularizer: Regularizer function applied to the bias vector. Default:
156      `None`.
157    kernel_constraint: Constraint function applied to the `kernel` weights
158      matrix. Default: `None`.
159    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
160      weights matrix. Default: `None`.
161    bias_constraint: Constraint function applied to the bias vector. Default:
162      `None`.
163    dropout: Float between 0 and 1. Fraction of the units to drop for the
164      linear transformation of the inputs. Default: 0.
165    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
166      the linear transformation of the recurrent state. Default: 0.
167    reset_after: GRU convention (whether to apply reset gate after or
168      before matrix multiplication). False = "before",
169      True = "after" (default and CuDNN compatible).
170
  Call arguments:
    inputs: A 2D tensor, with shape of `[batch, feature]`.
    states: A 2D tensor with shape of `[batch, units]`, which is the state from
      the previous time step. For timestep 0, the initial state provided by the
      user will be fed to the cell.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
179  """
180
181  def __init__(self,
182               units,
183               activation='tanh',
184               recurrent_activation='sigmoid',
185               use_bias=True,
186               kernel_initializer='glorot_uniform',
187               recurrent_initializer='orthogonal',
188               bias_initializer='zeros',
189               kernel_regularizer=None,
190               recurrent_regularizer=None,
191               bias_regularizer=None,
192               kernel_constraint=None,
193               recurrent_constraint=None,
194               bias_constraint=None,
195               dropout=0.,
196               recurrent_dropout=0.,
197               reset_after=True,
198               **kwargs):
199    super(GRUCell, self).__init__(
200        units,
201        activation=activation,
202        recurrent_activation=recurrent_activation,
203        use_bias=use_bias,
204        kernel_initializer=kernel_initializer,
205        recurrent_initializer=recurrent_initializer,
206        bias_initializer=bias_initializer,
207        kernel_regularizer=kernel_regularizer,
208        recurrent_regularizer=recurrent_regularizer,
209        bias_regularizer=bias_regularizer,
210        kernel_constraint=kernel_constraint,
211        recurrent_constraint=recurrent_constraint,
212        bias_constraint=bias_constraint,
213        dropout=dropout,
214        recurrent_dropout=recurrent_dropout,
215        implementation=kwargs.pop('implementation', 2),
216        reset_after=reset_after,
217        **kwargs)
218
219
220@keras_export('keras.layers.GRU', v1=[])
221class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
222  """Gated Recurrent Unit - Cho et al. 2014.
223
224  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
225  for details about the usage of RNN API.
226
227  Based on available runtime hardware and constraints, this layer
228  will choose different implementations (cuDNN-based or pure-TensorFlow)
229  to maximize the performance. If a GPU is available and all
230  the arguments to the layer meet the requirement of the CuDNN kernel
231  (see below for details), the layer will use a fast cuDNN implementation.
232
233  The requirements to use the cuDNN implementation are:
234
235  1. `activation` == `tanh`
236  2. `recurrent_activation` == `sigmoid`
237  3. `recurrent_dropout` == 0
238  4. `unroll` is `False`
239  5. `use_bias` is `True`
240  6. `reset_after` is `True`
  7. Inputs, if masking is used, are strictly right-padded.
242  8. Eager execution is enabled in the outermost context.
243
  There are two variants of the GRU implementation. The default one is based on
  [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied to the
  hidden state before matrix multiplication. The other one is based on the
  [original](https://arxiv.org/abs/1406.1078v1) and has the order reversed.

  The second variant is compatible with CuDNNGRU (GPU-only) and allows
  inference on CPU. Thus it has separate biases for `kernel` and
  `recurrent_kernel`. To use this variant, set `reset_after=True` and
  `recurrent_activation='sigmoid'`.
253
254  For example:
255
256  >>> inputs = tf.random.normal([32, 10, 8])
257  >>> gru = tf.keras.layers.GRU(4)
258  >>> output = gru(inputs)
259  >>> print(output.shape)
260  (32, 4)
261  >>> gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
262  >>> whole_sequence_output, final_state = gru(inputs)
263  >>> print(whole_sequence_output.shape)
264  (32, 10, 4)
265  >>> print(final_state.shape)
266  (32, 4)
267
268  Args:
269    units: Positive integer, dimensionality of the output space.
270    activation: Activation function to use.
271      Default: hyperbolic tangent (`tanh`).
272      If you pass `None`, no activation is applied
273      (ie. "linear" activation: `a(x) = x`).
274    recurrent_activation: Activation function to use
275      for the recurrent step.
276      Default: sigmoid (`sigmoid`).
277      If you pass `None`, no activation is applied
278      (ie. "linear" activation: `a(x) = x`).
279    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
280    kernel_initializer: Initializer for the `kernel` weights matrix,
281      used for the linear transformation of the inputs. Default:
282      `glorot_uniform`.
283    recurrent_initializer: Initializer for the `recurrent_kernel`
284       weights matrix, used for the linear transformation of the recurrent
285       state. Default: `orthogonal`.
286    bias_initializer: Initializer for the bias vector. Default: `zeros`.
287    kernel_regularizer: Regularizer function applied to the `kernel` weights
288      matrix. Default: `None`.
289    recurrent_regularizer: Regularizer function applied to the
290      `recurrent_kernel` weights matrix. Default: `None`.
291    bias_regularizer: Regularizer function applied to the bias vector. Default:
292      `None`.
293    activity_regularizer: Regularizer function applied to the output of the
294      layer (its "activation"). Default: `None`.
295    kernel_constraint: Constraint function applied to the `kernel` weights
296      matrix. Default: `None`.
297    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
298      weights matrix. Default: `None`.
299    bias_constraint: Constraint function applied to the bias vector. Default:
300      `None`.
301    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
302      transformation of the inputs. Default: 0.
303    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
304      the linear transformation of the recurrent state. Default: 0.
305    return_sequences: Boolean. Whether to return the last output
306      in the output sequence, or the full sequence. Default: `False`.
307    return_state: Boolean. Whether to return the last state in addition to the
308      output. Default: `False`.
309    go_backwards: Boolean (default `False`).
310      If True, process the input sequence backwards and return the
311      reversed sequence.
312    stateful: Boolean (default False). If True, the last state
313      for each sample at index i in a batch will be used as initial
314      state for the sample of index i in the following batch.
315    unroll: Boolean (default False).
316      If True, the network will be unrolled,
317      else a symbolic loop will be used.
318      Unrolling can speed-up a RNN,
319      although it tends to be more memory-intensive.
320      Unrolling is only suitable for short sequences.
321    time_major: The shape format of the `inputs` and `outputs` tensors.
322      If True, the inputs and outputs will be in shape
323      `[timesteps, batch, feature]`, whereas in the False case, it will be
324      `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
325      efficient because it avoids transposes at the beginning and end of the
326      RNN calculation. However, most TensorFlow data is batch-major, so by
327      default this function accepts input and emits output in batch-major
328      form.
329    reset_after: GRU convention (whether to apply reset gate after or
330      before matrix multiplication). False = "before",
331      True = "after" (default and CuDNN compatible).
332
333  Call arguments:
334    inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
    mask: Binary tensor of shape `[batch, timesteps]` indicating whether
      a given timestep should be masked (optional, defaults to `None`).
      An individual `True` entry indicates that the corresponding timestep
      should be utilized, while a `False` entry indicates that the
      corresponding timestep should be ignored.
340    training: Python boolean indicating whether the layer should behave in
341      training mode or in inference mode. This argument is passed to the cell
342      when calling it. This is only relevant if `dropout` or
343      `recurrent_dropout` is used  (optional, defaults to `None`).
344    initial_state: List of initial state tensors to be passed to the first
345      call of the cell  (optional, defaults to `None` which causes creation
346      of zero-filled initial state tensors).
347  """
348
349  def __init__(self,
350               units,
351               activation='tanh',
352               recurrent_activation='sigmoid',
353               use_bias=True,
354               kernel_initializer='glorot_uniform',
355               recurrent_initializer='orthogonal',
356               bias_initializer='zeros',
357               kernel_regularizer=None,
358               recurrent_regularizer=None,
359               bias_regularizer=None,
360               activity_regularizer=None,
361               kernel_constraint=None,
362               recurrent_constraint=None,
363               bias_constraint=None,
364               dropout=0.,
365               recurrent_dropout=0.,
366               return_sequences=False,
367               return_state=False,
368               go_backwards=False,
369               stateful=False,
370               unroll=False,
371               time_major=False,
372               reset_after=True,
373               **kwargs):
374    # return_runtime is a flag for testing, which shows the real backend
375    # implementation chosen by grappler in graph mode.
376    self._return_runtime = kwargs.pop('return_runtime', False)
377
378    super(GRU, self).__init__(
379        units,
380        activation=activation,
381        recurrent_activation=recurrent_activation,
382        use_bias=use_bias,
383        kernel_initializer=kernel_initializer,
384        recurrent_initializer=recurrent_initializer,
385        bias_initializer=bias_initializer,
386        kernel_regularizer=kernel_regularizer,
387        recurrent_regularizer=recurrent_regularizer,
388        bias_regularizer=bias_regularizer,
389        activity_regularizer=activity_regularizer,
390        kernel_constraint=kernel_constraint,
391        recurrent_constraint=recurrent_constraint,
392        bias_constraint=bias_constraint,
393        dropout=dropout,
394        recurrent_dropout=recurrent_dropout,
395        implementation=kwargs.pop('implementation', 2),
396        return_sequences=return_sequences,
397        return_state=return_state,
398        go_backwards=go_backwards,
399        stateful=stateful,
400        unroll=unroll,
401        time_major=time_major,
402        reset_after=reset_after,
403        **kwargs)
    # The GPU kernel uses the following settings by default; they are not
    # configurable.
    self._could_use_gpu_kernel = (
        self.activation in (activations.tanh, nn.tanh) and
        self.recurrent_activation in (activations.sigmoid, nn.sigmoid) and
        recurrent_dropout == 0 and not unroll and use_bias and
        reset_after and ops.executing_eagerly_outside_functions())
    if config.list_logical_devices('GPU'):
      # Only show the message when a GPU is available; the user will not care
      # about cuDNN if there is no GPU.
      if self._could_use_gpu_kernel:
        logging.debug(_CUDNN_AVAILABLE_MSG % self.name)
      else:
        logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)

    if _use_new_code():
      self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
420
421  def call(self, inputs, mask=None, training=None, initial_state=None):
422    # The input should be dense, padded with zeros. If a ragged input is fed
423    # into the layer, it is padded and the row lengths are used for masking.
424    inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
425    is_ragged_input = (row_lengths is not None)
426    self._validate_args_if_ragged(is_ragged_input, mask)
427
    # GRU does not support constants. Ignore them during processing.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)
430
431    if isinstance(mask, list):
432      mask = mask[0]
433
434    input_shape = K.int_shape(inputs)
435    timesteps = input_shape[0] if self.time_major else input_shape[1]
436
437    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
438    # inputs.
439    if is_ragged_input or not self._could_use_gpu_kernel:
440      kwargs = {'training': training}
441      self._maybe_reset_cell_dropout_mask(self.cell)
442
443      def step(cell_inputs, cell_states):
444        return self.cell(cell_inputs, cell_states, **kwargs)
445
446      last_output, outputs, states = K.rnn(
447          step,
448          inputs,
449          initial_state,
450          constants=None,
451          go_backwards=self.go_backwards,
452          mask=mask,
453          unroll=self.unroll,
454          input_length=row_lengths if row_lengths is not None else timesteps,
455          time_major=self.time_major,
456          zero_output_for_mask=self.zero_output_for_mask)
      # This is a dummy tensor for testing purposes.
      runtime = _runtime(_RUNTIME_UNKNOWN)
459    else:
460      last_output, outputs, runtime, states = self._defun_gru_call(
461          inputs, initial_state, training, mask, row_lengths)
462
463    if self.stateful:
464      updates = [state_ops.assign(self.states[0], states[0])]
465      self.add_update(updates)
466
467    if self.return_sequences:
468      output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
469    else:
470      output = last_output
471
472    if self.return_state:
473      return [output] + list(states)
474    elif self._return_runtime:
475      return output, runtime
476    else:
477      return output
478
479  def _defun_gru_call(self, inputs, initial_state, training, mask,
480                      sequence_lengths):
    # Use the new defun approach for backend implementation swap.
    # Note that different implementations need to have the same function
    # signature, e.g. the tensor parameters need to have the same shapes and
    # dtypes.
484
485    self.reset_dropout_mask()
486    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
487    if dropout_mask is not None:
488      inputs = inputs * dropout_mask[0]
489
490    if _use_new_code():
491      gru_kwargs = {
492          'inputs': inputs,
493          'init_h': _read_variable_value(initial_state[0]),
494          'kernel': _read_variable_value(self.cell.kernel),
495          'recurrent_kernel': _read_variable_value(self.cell.recurrent_kernel),
496          'bias': _read_variable_value(self.cell.bias),
497          'mask': mask,
498          'time_major': self.time_major,
499          'go_backwards': self.go_backwards,
500          'sequence_lengths': sequence_lengths,
501          'zero_output_for_mask': self.zero_output_for_mask
502      }
503      (last_output, outputs, new_h,
504       runtime) = self._defun_wrapper.defun_layer(**gru_kwargs)
505    else:
506      gpu_gru_kwargs = {
507          'inputs': inputs,
508          'init_h': _read_variable_value(initial_state[0]),
509          'kernel': _read_variable_value(self.cell.kernel),
510          'recurrent_kernel': _read_variable_value(self.cell.recurrent_kernel),
511          'bias': _read_variable_value(self.cell.bias),
512          'mask': mask,
513          'time_major': self.time_major,
514          'go_backwards': self.go_backwards,
515          'sequence_lengths': sequence_lengths
516      }
517      normal_gru_kwargs = gpu_gru_kwargs.copy()
518      normal_gru_kwargs.update({
519          'zero_output_for_mask': self.zero_output_for_mask,
520      })
521
522      if context.executing_eagerly():
523        device_type = _get_context_device_type()
524        can_use_gpu = (
525            # Either user specified GPU or unspecified but GPU is available.
526            (device_type == _GPU_DEVICE_NAME or
527             (device_type is None and config.list_logical_devices('GPU'))) and
528            (mask is None or is_cudnn_supported_inputs(mask, self.time_major)))
        # Under eager context, check the device placement and prefer the
        # GPU implementation when a GPU is available.
530        if can_use_gpu:
531          last_output, outputs, new_h, runtime = gpu_gru(**gpu_gru_kwargs)
532        else:
533          last_output, outputs, new_h, runtime = standard_gru(
534              **normal_gru_kwargs)
535      else:
536        last_output, outputs, new_h, runtime = gru_with_backend_selection(
537            **normal_gru_kwargs)
538
539    states = [new_h]
540    return last_output, outputs, runtime, states
541
542
543def standard_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask,
544                 time_major, go_backwards, sequence_lengths,
545                 zero_output_for_mask):
546  """GRU with standard kernel implementation.
547
  This implementation can be run on all types of hardware.

  This implementation lifts out all the layer weights and makes them function
  parameters. It has the same number of tensor input parameters as the cuDNN
  counterpart. The RNN step logic has been simplified, e.g. dropout and mask
  are removed since the cuDNN implementation does not support them.
554
555  Args:
556    inputs: Input tensor of GRU layer.
557    init_h: Initial state tensor for the cell output.
558    kernel: Weights for cell kernel.
559    recurrent_kernel: Weights for cell recurrent kernel.
560    bias: Weights for cell kernel bias and recurrent bias. The bias contains the
561      combined input_bias and recurrent_bias.
562    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
563      a given timestep should be masked. An individual `True` entry indicates
564      that the corresponding timestep should be utilized, while a `False` entry
565      indicates that the corresponding timestep should be ignored.
566    time_major: Boolean, whether the inputs are in the format of
567      [time, batch, feature] or [batch, time, feature].
568    go_backwards: Boolean (default False). If True, process the input sequence
569      backwards and return the reversed sequence.
570    sequence_lengths: The lengths of all sequences coming from a variable length
571      input, such as ragged tensors. If the input has a fixed timestep size,
572      this should be None.
573    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
574
  Returns:
    last_output: Output tensor for the last timestep, which has shape
      [batch, units].
    outputs: Output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: The cell output, which has the same shape as init_h.
    runtime: Constant string tensor which indicates the real runtime hardware.
      This value is for testing purposes only and should not be relied on by
      users.
  """
584  input_shape = K.int_shape(inputs)
585  timesteps = input_shape[0] if time_major else input_shape[1]
586
587  input_bias, recurrent_bias = array_ops.unstack(bias)
588
589  def step(cell_inputs, cell_states):
590    """Step function that will be used by Keras RNN backend."""
591    h_tm1 = cell_states[0]
592
593    # inputs projected by all gate matrices at once
594    matrix_x = K.dot(cell_inputs, kernel)
595    matrix_x = K.bias_add(matrix_x, input_bias)
596
597    x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=1)
598
599    # hidden state projected by all gate matrices at once
600    matrix_inner = K.dot(h_tm1, recurrent_kernel)
601    matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
602
603    recurrent_z, recurrent_r, recurrent_h = array_ops.split(matrix_inner, 3,
604                                                            axis=1)
605    z = nn.sigmoid(x_z + recurrent_z)
606    r = nn.sigmoid(x_r + recurrent_r)
607    hh = nn.tanh(x_h + r * recurrent_h)
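    # The three lines above implement the GRU gate equations in the
    # `reset_after` convention, where both the input and the recurrent
    # projections carry their own bias:
    #   z  = sigmoid(x W_z + b_zi + h_tm1 U_z + b_zr)      (update gate)
    #   r  = sigmoid(x W_r + b_ri + h_tm1 U_r + b_rr)      (reset gate)
    #   hh = tanh(x W_h + b_hi + r * (h_tm1 U_h + b_hr))   (candidate state)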
608
609    # previous and candidate state mixed by update gate
610    h = z * h_tm1 + (1 - z) * hh
611    return h, [h]
612
613  last_output, outputs, new_states = K.rnn(
614      step,
615      inputs, [init_h],
616      constants=None,
617      unroll=False,
618      time_major=time_major,
619      mask=mask,
620      go_backwards=go_backwards,
621      input_length=sequence_lengths
622      if sequence_lengths is not None else timesteps,
623      zero_output_for_mask=zero_output_for_mask)
624  return last_output, outputs, new_states[0], _runtime(_RUNTIME_CPU)
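
# A minimal sketch of calling `standard_gru` directly (shapes are assumed for
# illustration; in practice the GRU layer above prepares these tensors):
#
#   batch, timesteps, feature, units = 2, 5, 3, 4
#   last_output, outputs, state, runtime = standard_gru(
#       inputs=array_ops.zeros([batch, timesteps, feature]),
#       init_h=array_ops.zeros([batch, units]),
#       kernel=array_ops.zeros([feature, 3 * units]),
#       recurrent_kernel=array_ops.zeros([units, 3 * units]),
#       bias=array_ops.zeros([2, 3 * units]),
#       mask=None, time_major=False, go_backwards=False,
#       sequence_lengths=None, zero_output_for_mask=False)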
625
626
627def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
628            go_backwards, sequence_lengths):
629  """GRU with CuDNN implementation which is only available for GPU."""
630  if not time_major and mask is None:
631    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
632    seq_axis, batch_axis = (0, 1)
633  else:
634    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
  # For init_h, cuDNN expects one more dim of num_layers before or after the
  # batch dim for time-major or batch-major inputs respectively.
  init_h = array_ops.expand_dims(init_h, axis=seq_axis)

  weights = array_ops.split(kernel, 3, axis=1)
  weights += array_ops.split(recurrent_kernel, 3, axis=1)
  # Note that the bias was initialized as shape (2, 3 * units), flatten it into
  # (6 * units).
  bias = array_ops.split(K.flatten(bias), 6)
644
  if sysconfig.get_build_info()['is_cuda_build']:
    # Note that the gate order for cuDNN is different from the canonical
    # format. The canonical format is [z, r, h], whereas cuDNN uses [r, z, h].
    # The swap needs to be done for kernel, recurrent_kernel, input_bias and
    # recurrent_bias.
    # z is the update gate weights.
    # r is the reset gate weights.
    # h is the candidate (new) gate weights.
    weights[0], weights[1] = weights[1], weights[0]
    weights[3], weights[4] = weights[4], weights[3]
    bias[0], bias[1] = bias[1], bias[0]
    bias[3], bias[4] = bias[4], bias[3]
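    # For example, a canonical kernel split [W_z, W_r, W_h] becomes
    # [W_r, W_z, W_h] after the swap, matching the cuDNN gate order.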
656
657  params = _canonical_to_params(
658      weights=weights,
659      biases=bias,
660      shape=constant_op.constant([-1]),
661      transpose_weights=True)
662
663  if mask is not None:
664    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
665
666  if sequence_lengths is not None:
667    if go_backwards:
668      # Three reversals are required. E.g.,
669      # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
670      # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
671      # output_from_cudnn = [6, 5, 4, 0, 0]
672      # expected_output = [0, 0, 6, 5 ,4]
673      inputs = array_ops.reverse_sequence_v2(
674          inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
675    outputs, h, _, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
676        input=inputs,
677        input_h=init_h,
678        input_c=0,
679        params=params,
680        is_training=True,
681        rnn_mode='gru',
682        sequence_lengths=sequence_lengths,
683        time_major=time_major)
684    if go_backwards:
685      outputs = array_ops.reverse_sequence_v2(
686          outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
687      outputs = array_ops.reverse(outputs, axis=[seq_axis])
688  else:
689    if go_backwards:
      # Reverse axis 0 since the input is already converted to time major.
      inputs = array_ops.reverse(inputs, axis=[0])
692    outputs, h, _, _ = gen_cudnn_rnn_ops.CudnnRNN(
693        input=inputs, input_h=init_h, input_c=0, params=params,
694        is_training=True, rnn_mode='gru')
695
696  last_output = outputs[-1]
697  if not time_major and mask is None:
698    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
699  h = array_ops.squeeze(h, axis=seq_axis)
700
  # In the case of variable-length input, the cuDNN kernel fills zeros for the
  # output, whereas the default Keras behavior is to carry over the output from
  # t-1, so that in the return_sequences=False case the user can get the final
  # effective output instead of just 0s at the last timestep.
  # In order to mimic the default Keras behavior, we copy the final h state as
  # the last_output, since it is numerically the same as the output.
  if mask is not None:
    last_output = h
709
710  return last_output, outputs, h, _runtime(_RUNTIME_GPU)
711
712
713def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
714                               mask, time_major, go_backwards, sequence_lengths,
715                               zero_output_for_mask):
716  """Call the GRU with optimized backend kernel selection.
717
  Under the hood, this function will create two TF functions, one with the most
  generic kernel that can run under all device conditions, and a second one
  with a cuDNN-specific kernel, which can only run on GPU.

  The first function will be called with the normal GRU params, while the
  second function is not called, but only registered in the graph. Grappler
  will do the proper graph rewrite and swap in the optimized TF function based
  on the device placement.
726
727  Args:
728    inputs: Input tensor of GRU layer.
729    init_h: Initial state tensor for the cell output.
730    kernel: Weights for cell kernel.
731    recurrent_kernel: Weights for cell recurrent kernel.
732    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
733      is used in this case.
    mask: Boolean tensor used to mask out the timesteps within the sequence.
      An individual `True` entry indicates that the corresponding timestep
      should be utilized, while a `False` entry indicates that the corresponding
      timestep should be ignored.
738    time_major: Boolean, whether the inputs are in the format of
739      [time, batch, feature] or [batch, time, feature].
740    go_backwards: Boolean (default False). If True, process the input sequence
741      backwards and return the reversed sequence.
742    sequence_lengths: The lengths of all sequences coming from a variable length
743      input, such as ragged tensors. If the input has a fixed timestep size,
744      this should be None.
745    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
746
747  Returns:
748    List of output tensors, same as standard_gru.
749  """
750  params = {
751      'inputs': inputs,
752      'init_h': init_h,
753      'kernel': kernel,
754      'recurrent_kernel': recurrent_kernel,
755      'bias': bias,
756      'mask': mask,
757      'time_major': time_major,
758      'go_backwards': go_backwards,
759      'sequence_lengths': sequence_lengths,
760      'zero_output_for_mask': zero_output_for_mask,
761  }
762
763  def gpu_gru_with_fallback(inputs, init_h, kernel, recurrent_kernel, bias,
764                            mask, time_major, go_backwards, sequence_lengths,
765                            zero_output_for_mask):
766    """Use CuDNN kernel when mask is none or strictly right padded."""
767    if mask is None:
768      return gpu_gru(
769          inputs=inputs,
770          init_h=init_h,
771          kernel=kernel,
772          recurrent_kernel=recurrent_kernel,
773          bias=bias,
774          mask=mask,
775          time_major=time_major,
776          go_backwards=go_backwards,
777          sequence_lengths=sequence_lengths)
778
779    def cudnn_gru_fn():
780      return gpu_gru(
781          inputs=inputs,
782          init_h=init_h,
783          kernel=kernel,
784          recurrent_kernel=recurrent_kernel,
785          bias=bias,
786          mask=mask,
787          time_major=time_major,
788          go_backwards=go_backwards,
789          sequence_lengths=sequence_lengths)
790
791    def standard_gru_fn():
792      return standard_gru(
793          inputs=inputs,
794          init_h=init_h,
795          kernel=kernel,
796          recurrent_kernel=recurrent_kernel,
797          bias=bias,
798          mask=mask,
799          time_major=time_major,
800          go_backwards=go_backwards,
801          sequence_lengths=sequence_lengths,
802          zero_output_for_mask=zero_output_for_mask)
803
804    return control_flow_ops.cond(
805        is_cudnn_supported_inputs(mask, time_major),
806        true_fn=cudnn_gru_fn,
807        false_fn=standard_gru_fn)
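    # Note: `is_cudnn_supported_inputs` returns a boolean tensor, so the choice
    # between the cuDNN kernel and the generic kernel above is made when the
    # graph executes, not when it is traced.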
808
809  if _use_new_code():
810    # Chooses the implementation dynamically based on the running device.
811    (last_output, outputs, new_h,
812     runtime) = control_flow_ops.execute_fn_for_device(
813         {
814             _CPU_DEVICE_NAME: lambda: standard_gru(**params),
815             _GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(**params)
816         }, lambda: standard_gru(**params))
817  else:
818    # Each time a `tf.function` is called, we will give it a unique
819    # identifiable API name, so that Grappler won't get confused when it
820    # sees multiple GRU layers added into same graph, and it will be able
821    # to pair up the different implementations across them.
822    api_name = 'gru_' + str(uuid.uuid4())
823    supportive_attribute = {
824        'time_major': time_major,
825        'go_backwards': go_backwards,
826    }
827    defun_standard_gru = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
828                                                 standard_gru,
829                                                 supportive_attribute)
830    defun_gpu_gru = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
831                                            gpu_gru_with_fallback,
832                                            supportive_attribute)
833
834    # Call the normal GRU impl and register the CuDNN impl function. The
835    # grappler will kick in during session execution to optimize the graph.
836    last_output, outputs, new_h, runtime = defun_standard_gru(**params)
837    _function_register(defun_gpu_gru, **params)
838
839  return last_output, outputs, new_h, runtime
840
841
842@keras_export('keras.layers.LSTMCell', v1=[])
843class LSTMCell(recurrent.LSTMCell):
844  """Cell class for the LSTM layer.
845
846  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
847  for details about the usage of RNN API.
848
  This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.LSTM` processes the whole sequence.
851
852  For example:
853
854  >>> inputs = tf.random.normal([32, 10, 8])
855  >>> rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
856  >>> output = rnn(inputs)
857  >>> print(output.shape)
858  (32, 4)
859  >>> rnn = tf.keras.layers.RNN(
860  ...    tf.keras.layers.LSTMCell(4),
861  ...    return_sequences=True,
862  ...    return_state=True)
863  >>> whole_seq_output, final_memory_state, final_carry_state = rnn(inputs)
864  >>> print(whole_seq_output.shape)
865  (32, 10, 4)
866  >>> print(final_memory_state.shape)
867  (32, 4)
868  >>> print(final_carry_state.shape)
869  (32, 4)
870
871  Args:
872    units: Positive integer, dimensionality of the output space.
873    activation: Activation function to use. Default: hyperbolic tangent
874      (`tanh`). If you pass `None`, no activation is applied (ie. "linear"
875      activation: `a(x) = x`).
876    recurrent_activation: Activation function to use for the recurrent step.
877      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is applied
878      (ie. "linear" activation: `a(x) = x`).
879    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
880    kernel_initializer: Initializer for the `kernel` weights matrix, used for
881      the linear transformation of the inputs. Default: `glorot_uniform`.
882    recurrent_initializer: Initializer for the `recurrent_kernel` weights
883      matrix, used for the linear transformation of the recurrent state.
884      Default: `orthogonal`.
885    bias_initializer: Initializer for the bias vector. Default: `zeros`.
886    unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
887      the forget gate at initialization. Setting it to true will also force
888      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
889        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
890    kernel_regularizer: Regularizer function applied to the `kernel` weights
891      matrix. Default: `None`.
892    recurrent_regularizer: Regularizer function applied to
893      the `recurrent_kernel` weights matrix. Default: `None`.
894    bias_regularizer: Regularizer function applied to the bias vector. Default:
895      `None`.
896    kernel_constraint: Constraint function applied to the `kernel` weights
897      matrix. Default: `None`.
898    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
899      weights matrix. Default: `None`.
900    bias_constraint: Constraint function applied to the bias vector. Default:
901      `None`.
902    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
903      transformation of the inputs. Default: 0.
904    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
905      the linear transformation of the recurrent state. Default: 0.
906
  Call arguments:
    inputs: A 2D tensor, with shape of `[batch, feature]`.
    states: List of 2 tensors corresponding to the cell's memory and carry
      states. Both of them have shape `[batch, units]`: the first tensor is the
      memory state from the previous time step, and the second tensor is the
      carry state from the previous time step. For timestep 0, the initial
      state provided by the user will be fed to the cell.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
917  """
918
919  def __init__(self,
920               units,
921               activation='tanh',
922               recurrent_activation='sigmoid',
923               use_bias=True,
924               kernel_initializer='glorot_uniform',
925               recurrent_initializer='orthogonal',
926               bias_initializer='zeros',
927               unit_forget_bias=True,
928               kernel_regularizer=None,
929               recurrent_regularizer=None,
930               bias_regularizer=None,
931               kernel_constraint=None,
932               recurrent_constraint=None,
933               bias_constraint=None,
934               dropout=0.,
935               recurrent_dropout=0.,
936               **kwargs):
937    super(LSTMCell, self).__init__(
938        units,
939        activation=activation,
940        recurrent_activation=recurrent_activation,
941        use_bias=use_bias,
942        kernel_initializer=kernel_initializer,
943        recurrent_initializer=recurrent_initializer,
944        bias_initializer=bias_initializer,
945        unit_forget_bias=unit_forget_bias,
946        kernel_regularizer=kernel_regularizer,
947        recurrent_regularizer=recurrent_regularizer,
948        bias_regularizer=bias_regularizer,
949        kernel_constraint=kernel_constraint,
950        recurrent_constraint=recurrent_constraint,
951        bias_constraint=bias_constraint,
952        dropout=dropout,
953        recurrent_dropout=recurrent_dropout,
954        implementation=kwargs.pop('implementation', 2),
955        **kwargs)
956
957
958@keras_export('keras.layers.LSTM', v1=[])
959class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
960  """Long Short-Term Memory layer - Hochreiter 1997.
961
962  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
963  for details about the usage of RNN API.
964
965  Based on available runtime hardware and constraints, this layer
966  will choose different implementations (cuDNN-based or pure-TensorFlow)
967  to maximize the performance. If a GPU is available and all
968  the arguments to the layer meet the requirement of the CuDNN kernel
969  (see below for details), the layer will use a fast cuDNN implementation.
970
971  The requirements to use the cuDNN implementation are:
972
973  1. `activation` == `tanh`
974  2. `recurrent_activation` == `sigmoid`
975  3. `recurrent_dropout` == 0
976  4. `unroll` is `False`
977  5. `use_bias` is `True`
  6. Inputs, if masking is used, are strictly right-padded.
979  7. Eager execution is enabled in the outermost context.
980
981  For example:
982
983  >>> inputs = tf.random.normal([32, 10, 8])
984  >>> lstm = tf.keras.layers.LSTM(4)
985  >>> output = lstm(inputs)
986  >>> print(output.shape)
987  (32, 4)
988  >>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
989  >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
990  >>> print(whole_seq_output.shape)
991  (32, 10, 4)
992  >>> print(final_memory_state.shape)
993  (32, 4)
994  >>> print(final_carry_state.shape)
995  (32, 4)
996
997  Args:
998    units: Positive integer, dimensionality of the output space.
999    activation: Activation function to use.
1000      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
1001      is applied (ie. "linear" activation: `a(x) = x`).
1002    recurrent_activation: Activation function to use for the recurrent step.
1003      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
1004      applied (ie. "linear" activation: `a(x) = x`).
1005    use_bias: Boolean (default `True`), whether the layer uses a bias vector.
1006    kernel_initializer: Initializer for the `kernel` weights matrix, used for
1007      the linear transformation of the inputs. Default: `glorot_uniform`.
1008    recurrent_initializer: Initializer for the `recurrent_kernel` weights
1009      matrix, used for the linear transformation of the recurrent state.
1010      Default: `orthogonal`.
1011    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1012    unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
1013      the forget gate at initialization. Setting it to true will also force
1014      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
1015          al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
1016    kernel_regularizer: Regularizer function applied to the `kernel` weights
1017      matrix. Default: `None`.
1018    recurrent_regularizer: Regularizer function applied to the
1019      `recurrent_kernel` weights matrix. Default: `None`.
1020    bias_regularizer: Regularizer function applied to the bias vector. Default:
1021      `None`.
1022    activity_regularizer: Regularizer function applied to the output of the
1023      layer (its "activation"). Default: `None`.
1024    kernel_constraint: Constraint function applied to the `kernel` weights
1025      matrix. Default: `None`.
1026    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1027      weights matrix. Default: `None`.
1028    bias_constraint: Constraint function applied to the bias vector. Default:
1029      `None`.
1030    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
1031      transformation of the inputs. Default: 0.
1032    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
1033      the linear transformation of the recurrent state. Default: 0.
    return_sequences: Boolean. Whether to return the last output in the output
      sequence, or the full sequence. Default: `False`.
1036    return_state: Boolean. Whether to return the last state in addition to the
1037      output. Default: `False`.
1038    go_backwards: Boolean (default `False`). If True, process the input sequence
1039      backwards and return the reversed sequence.
1040    stateful: Boolean (default `False`). If True, the last state for each sample
1041      at index i in a batch will be used as initial state for the sample of
1042      index i in the following batch.
1043    time_major: The shape format of the `inputs` and `outputs` tensors.
1044      If True, the inputs and outputs will be in shape
1045      `[timesteps, batch, feature]`, whereas in the False case, it will be
1046      `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
1047      efficient because it avoids transposes at the beginning and end of the
1048      RNN calculation. However, most TensorFlow data is batch-major, so by
1049      default this function accepts input and emits output in batch-major
1050      form.
1051    unroll: Boolean (default `False`). If True, the network will be unrolled,
1052      else a symbolic loop will be used. Unrolling can speed-up a RNN, although
1053      it tends to be more memory-intensive. Unrolling is only suitable for short
1054      sequences.
1055
1056  Call arguments:
1057    inputs: A 3D tensor with shape `[batch, timesteps, feature]`.
1058    mask: Binary tensor of shape `[batch, timesteps]` indicating whether
1059      a given timestep should be masked (optional, defaults to `None`).
1060      An individual `True` entry indicates that the corresponding timestep
1061      should be utilized, while a `False` entry indicates that the corresponding
1062      timestep should be ignored.
1063    training: Python boolean indicating whether the layer should behave in
1064      training mode or in inference mode. This argument is passed to the cell
1065      when calling it. This is only relevant if `dropout` or
1066      `recurrent_dropout` is used (optional, defaults to `None`).
1067    initial_state: List of initial state tensors to be passed to the first
1068      call of the cell (optional, defaults to `None` which causes creation
1069      of zero-filled initial state tensors).
1070  """
1071
1072  def __init__(self,
1073               units,
1074               activation='tanh',
1075               recurrent_activation='sigmoid',
1076               use_bias=True,
1077               kernel_initializer='glorot_uniform',
1078               recurrent_initializer='orthogonal',
1079               bias_initializer='zeros',
1080               unit_forget_bias=True,
1081               kernel_regularizer=None,
1082               recurrent_regularizer=None,
1083               bias_regularizer=None,
1084               activity_regularizer=None,
1085               kernel_constraint=None,
1086               recurrent_constraint=None,
1087               bias_constraint=None,
1088               dropout=0.,
1089               recurrent_dropout=0.,
1090               return_sequences=False,
1091               return_state=False,
1092               go_backwards=False,
1093               stateful=False,
1094               time_major=False,
1095               unroll=False,
1096               **kwargs):
1097    # return_runtime is a flag for testing, which shows the real backend
1098    # implementation chosen by grappler in graph mode.
1099    self.return_runtime = kwargs.pop('return_runtime', False)
1100
1101    super(LSTM, self).__init__(
1102        units,
1103        activation=activation,
1104        recurrent_activation=recurrent_activation,
1105        use_bias=use_bias,
1106        kernel_initializer=kernel_initializer,
1107        recurrent_initializer=recurrent_initializer,
1108        bias_initializer=bias_initializer,
1109        unit_forget_bias=unit_forget_bias,
1110        kernel_regularizer=kernel_regularizer,
1111        recurrent_regularizer=recurrent_regularizer,
1112        bias_regularizer=bias_regularizer,
1113        activity_regularizer=activity_regularizer,
1114        kernel_constraint=kernel_constraint,
1115        recurrent_constraint=recurrent_constraint,
1116        bias_constraint=bias_constraint,
1117        dropout=dropout,
1118        recurrent_dropout=recurrent_dropout,
1119        implementation=kwargs.pop('implementation', 2),
1120        return_sequences=return_sequences,
1121        return_state=return_state,
1122        go_backwards=go_backwards,
1123        stateful=stateful,
1124        time_major=time_major,
1125        unroll=unroll,
1126        **kwargs)
1127
1128    self.state_spec = [
1129        InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
1130    ]
1131    self._could_use_gpu_kernel = (
1132        self.activation in (activations.tanh, nn.tanh) and
1133        self.recurrent_activation in (activations.sigmoid, nn.sigmoid) and
1134        recurrent_dropout == 0 and not unroll and use_bias and
1135        ops.executing_eagerly_outside_functions())
1136    if config.list_logical_devices('GPU'):
1137      # Only show the message when there is GPU available, user will not care
1138      # about the cuDNN if there isn't any GPU.
1139      if self._could_use_gpu_kernel:
1140        logging.debug(_CUDNN_AVAILABLE_MSG % self.name)
1141      else:
1142        logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
1143
1144    if _use_new_code():
1145      self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'lstm')
1146
1147  def call(self, inputs, mask=None, training=None, initial_state=None):
1148    # The input should be dense, padded with zeros. If a ragged input is fed
1149    # into the layer, it is padded and the row lengths are used for masking.
1150    inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
1151    is_ragged_input = (row_lengths is not None)
1152    self._validate_args_if_ragged(is_ragged_input, mask)
1153
    # LSTM does not support constants. Ignore them during processing.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)
1156
1157    if isinstance(mask, list):
1158      mask = mask[0]
1159
1160    input_shape = K.int_shape(inputs)
1161    timesteps = input_shape[0] if self.time_major else input_shape[1]
1162
1163    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
1164    # inputs.
1165    if is_ragged_input or not self._could_use_gpu_kernel:
1166      # Fall back to use the normal LSTM.
1167      kwargs = {'training': training}
1168      self._maybe_reset_cell_dropout_mask(self.cell)
1169
1170      def step(inputs, states):
1171        return self.cell(inputs, states, **kwargs)
1172
1173      last_output, outputs, states = K.rnn(
1174          step,
1175          inputs,
1176          initial_state,
1177          constants=None,
1178          go_backwards=self.go_backwards,
1179          mask=mask,
1180          unroll=self.unroll,
1181          input_length=row_lengths if row_lengths is not None else timesteps,
1182          time_major=self.time_major,
1183          zero_output_for_mask=self.zero_output_for_mask)
1184      runtime = _runtime(_RUNTIME_UNKNOWN)
1185    else:
      # Use the new defun approach for backend implementation swap.
      # Note that different implementations need to have the same function
      # signature, e.g. the tensor parameters need to have the same shapes and
      # dtypes.
      # Since cuDNN has an extra set of biases, those biases will be passed to
      # both the normal and cuDNN implementations.
1191      self.reset_dropout_mask()
1192      dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
1193      if dropout_mask is not None:
1194        inputs = inputs * dropout_mask[0]
1195      if _use_new_code():
1196        lstm_kwargs = {
1197            'inputs':
1198                inputs,
1199            'init_h':
1200                _read_variable_value(initial_state[0]),
1201            'init_c':
1202                _read_variable_value(initial_state[1]),
1203            'kernel':
1204                _read_variable_value(self.cell.kernel),
1205            'recurrent_kernel':
1206                _read_variable_value(self.cell.recurrent_kernel),
1207            'bias':
1208                _read_variable_value(self.cell.bias),
1209            'mask':
1210                mask,
1211            'time_major':
1212                self.time_major,
1213            'go_backwards':
1214                self.go_backwards,
1215            'sequence_lengths':
1216                row_lengths,
1217            'zero_output_for_mask':
1218                self.zero_output_for_mask,
1219        }
1220        (last_output, outputs, new_h, new_c,
1221         runtime) = self._defun_wrapper.defun_layer(**lstm_kwargs)
1222      else:
1223        gpu_lstm_kwargs = {
1224            'inputs':
1225                inputs,
1226            'init_h':
1227                _read_variable_value(initial_state[0]),
1228            'init_c':
1229                _read_variable_value(initial_state[1]),
1230            'kernel':
1231                _read_variable_value(self.cell.kernel),
1232            'recurrent_kernel':
1233                _read_variable_value(self.cell.recurrent_kernel),
1234            'bias':
1235                _read_variable_value(self.cell.bias),
1236            'mask':
1237                mask,
1238            'time_major':
1239                self.time_major,
1240            'go_backwards':
1241                self.go_backwards,
1242            'sequence_lengths':
1243                row_lengths
1244        }
1245        normal_lstm_kwargs = gpu_lstm_kwargs.copy()
1246        normal_lstm_kwargs.update({
1247            'zero_output_for_mask': self.zero_output_for_mask,
1248        })
1249
1250        if context.executing_eagerly():
1251          device_type = _get_context_device_type()
1252          can_use_gpu = (
1253              # Either user specified GPU or unspecified but GPU is available.
1254              (device_type == _GPU_DEVICE_NAME or
1255               (device_type is None and config.list_logical_devices('GPU'))) and
1256              (mask is None or
1257               is_cudnn_supported_inputs(mask, self.time_major)))
1258          # Under eager context, check the device placement and prefer the
1259          # GPU implementation when GPU is available.
1260          if can_use_gpu:
1261            last_output, outputs, new_h, new_c, runtime = gpu_lstm(
1262                **gpu_lstm_kwargs)
1263          else:
1264            last_output, outputs, new_h, new_c, runtime = standard_lstm(
1265                **normal_lstm_kwargs)
1266        else:
1267          (last_output, outputs, new_h, new_c,
1268           runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
1269
1270      states = [new_h, new_c]
1271
1272    if self.stateful:
1273      updates = [
1274          state_ops.assign(self_state, state)
1275          for self_state, state in zip(self.states, states)
1276      ]
1277      self.add_update(updates)
1278
1279    if self.return_sequences:
1280      output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
1281    else:
1282      output = last_output
1283
1284    if self.return_state:
1285      return [output] + list(states)
1286    elif self.return_runtime:
1287      return output, runtime
1288    else:
1289      return output
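
# Example of how the return flags handled above shape the layer outputs (a
# minimal sketch using the public Keras API; shapes are illustrative):
#
#   layer = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
#   outputs, final_h, final_c = layer(tf.random.normal([32, 10, 8]))
#   # outputs: [32, 10, 4]; final_h and final_c: [32, 4]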
1290
1291
1292def _canonical_to_params(weights, biases, shape, transpose_weights=False):
1293  """Utility function convert variable to CuDNN compatible parameter.
1294
  Note that Keras kernel weights differ from the CuDNN format. For example:
1296
1297  ```
1298    Keras                 CuDNN
1299    [[0, 1, 2],  <--->  [[0, 2, 4],
1300     [3, 4, 5]]          [1, 3, 5]]
1301  ```
1302
1303  If the input weights need to be in a unified format, then set
1304  `transpose_weights=True` to convert the weights.
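
  For example (an illustrative sketch with assumed small shapes, using this
  module's import aliases):

  ```
    # Two 2x3 kernels and two biases of length 3, each flattened to 1-D and
    # concatenated into a single parameter vector of length 2*6 + 2*3 = 18.
    ws = [array_ops.ones([2, 3]), array_ops.zeros([2, 3])]
    bs = [array_ops.ones([3]), array_ops.zeros([3])]
    params = _canonical_to_params(
        ws, bs, shape=constant_op.constant([-1]), transpose_weights=True)
  ```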
1305
1306  Args:
1307    weights: list of weights for the individual kernels and recurrent kernels.
    biases: list of biases for the individual gates.
    shape: the shape for the converted variables that will be fed to CuDNN.
1310    transpose_weights: boolean, whether to transpose the weights.
1311
1312  Returns:
    The converted weights that can be fed to CuDNN ops as params.
1314  """
1315  def convert(w):
1316    return array_ops.transpose(w) if transpose_weights else w
1317
1318  weights = [array_ops.reshape(convert(x), shape) for x in weights]
1319  biases = [array_ops.reshape(x, shape) for x in biases]
1320  return array_ops.concat(weights + biases, axis=0)
1321
1322
1323def standard_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias,
1324                  mask, time_major, go_backwards, sequence_lengths,
1325                  zero_output_for_mask):
1326  """LSTM with standard kernel implementation.
1327
  This implementation can be run on all types of hardware.

  This implementation lifts out all the layer weights and makes them function
  parameters. It has the same number of tensor input params as the CuDNN
  counterpart. The RNN step logic has been simplified, e.g. dropout and mask
  handling are removed since the CuDNN implementation does not support them.

  Note that the first half of the bias tensor should be ignored by this impl.
  The CuDNN impl needs an extra set of input gate biases. In order to make both
  functions take the same shape of parameters, that extra set of biases is also
  fed here.
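
  Example (a minimal sketch with assumed shapes: batch=2, time=3, feature=4,
  units=5; all values are placeholders):

  ```
    inputs = array_ops.ones([2, 3, 4])
    init_h = array_ops.zeros([2, 5])
    init_c = array_ops.zeros([2, 5])
    kernel = array_ops.ones([4, 20])             # 4 gates x 5 units
    recurrent_kernel = array_ops.ones([5, 20])
    bias = array_ops.zeros([20])                 # bias shape assumed: 4 gates x 5 units
    last_output, outputs, h, c, runtime = standard_lstm(
        inputs, init_h, init_c, kernel, recurrent_kernel, bias,
        mask=None, time_major=False, go_backwards=False,
        sequence_lengths=None, zero_output_for_mask=False)
    # last_output: [2, 5]; outputs: [2, 3, 5]; h and c: [2, 5]
  ```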
1339
1340  Args:
1341    inputs: input tensor of LSTM layer.
1342    init_h: initial state tensor for the cell output.
1343    init_c: initial state tensor for the cell hidden state.
1344    kernel: weights for cell kernel.
1345    recurrent_kernel: weights for cell recurrent kernel.
1346    bias: weights for cell kernel bias and recurrent bias. Only recurrent bias
1347      is used in this case.
    mask: Boolean tensor for masking out the steps within the sequence.
1349      An individual `True` entry indicates that the corresponding timestep
1350      should be utilized, while a `False` entry indicates that the corresponding
1351      timestep should be ignored.
1352    time_major: boolean, whether the inputs are in the format of
1353      [time, batch, feature] or [batch, time, feature].
1354    go_backwards: Boolean (default False). If True, process the input sequence
1355      backwards and return the reversed sequence.
1356    sequence_lengths: The lengths of all sequences coming from a variable length
1357      input, such as ragged tensors. If the input has a fixed timestep size,
1358      this should be None.
1359    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
1360
1361  Returns:
1362    last_output: output tensor for the last timestep, which has shape
1363      [batch, units].
1364    outputs: output tensor for all timesteps, which has shape
1365      [batch, time, units].
1366    state_0: the cell output, which has same shape as init_h.
1367    state_1: the cell hidden state, which has same shape as init_c.
    runtime: constant string tensor which indicates the real runtime hardware.
      This value is for testing purposes and should not be used by users.
1370  """
1371  input_shape = K.int_shape(inputs)
1372  timesteps = input_shape[0] if time_major else input_shape[1]
1373
1374  def step(cell_inputs, cell_states):
1375    """Step function that will be used by Keras RNN backend."""
1376    h_tm1 = cell_states[0]  # previous memory state
1377    c_tm1 = cell_states[1]  # previous carry state
1378
1379    z = K.dot(cell_inputs, kernel)
1380    z += K.dot(h_tm1, recurrent_kernel)
1381    z = K.bias_add(z, bias)
1382
1383    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)
1384
1385    i = nn.sigmoid(z0)
1386    f = nn.sigmoid(z1)
1387    c = f * c_tm1 + i * nn.tanh(z2)
1388    o = nn.sigmoid(z3)
1389
1390    h = o * nn.tanh(c)
1391    return h, [h, c]
1392
1393  last_output, outputs, new_states = K.rnn(
1394      step,
1395      inputs, [init_h, init_c],
1396      constants=None,
1397      unroll=False,
1398      time_major=time_major,
1399      mask=mask,
1400      go_backwards=go_backwards,
1401      input_length=(sequence_lengths
1402                    if sequence_lengths is not None else timesteps),
1403      zero_output_for_mask=zero_output_for_mask)
1404  return (last_output, outputs, new_states[0], new_states[1],
1405          _runtime(_RUNTIME_CPU))
1406
1407
1408def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
1409             time_major, go_backwards, sequence_lengths):
1410  """LSTM with either CuDNN or ROCm implementation which is only available for GPU.
1411
1412  Note that currently only right padded data is supported, or the result will be
1413  polluted by the unmasked data which should be filtered.
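
  For example, a right-padded mask and the sequence lengths derived from it by
  `calculate_sequence_by_mask` below (an illustrative sketch; actually running
  this kernel requires a CuDNN-enabled GPU):

  ```
    # batch=2, time=3; the last step of the second sample is masked.
    mask = constant_op.constant([[True, True, True],
                                 [True, True, False]])
    calculate_sequence_by_mask(mask, time_major=False)  # -> [3, 2]
  ```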
1414
1415  Args:
1416    inputs: Input tensor of LSTM layer.
1417    init_h: Initial state tensor for the cell output.
1418    init_c: Initial state tensor for the cell hidden state.
1419    kernel: Weights for cell kernel.
1420    recurrent_kernel: Weights for cell recurrent kernel.
1421    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1422      is used in this case.
    mask: Boolean tensor for masking out the steps within the sequence.
1424      An individual `True` entry indicates that the corresponding timestep
1425      should be utilized, while a `False` entry indicates that the corresponding
1426      timestep should be ignored.
1427    time_major: Boolean, whether the inputs are in the format of [time, batch,
1428      feature] or [batch, time, feature].
1429    go_backwards: Boolean (default False). If True, process the input sequence
1430      backwards and return the reversed sequence.
1431    sequence_lengths: The lengths of all sequences coming from a variable length
1432      input, such as ragged tensors. If the input has a fixed timestep size,
1433      this should be None.
1434
1435  Returns:
1436    last_output: Output tensor for the last timestep, which has shape
1437      [batch, units].
1438    outputs: Output tensor for all timesteps, which has shape
1439      [batch, time, units].
1440    state_0: The cell output, which has same shape as init_h.
1441    state_1: The cell hidden state, which has same shape as init_c.
    runtime: Constant string tensor which indicates the real runtime hardware.
      This value is for testing purposes and should not be used by users.
1444  """
1445  if not time_major and mask is None:
1446    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
1447    seq_axis, batch_axis = (0, 1)
1448  else:
1449    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
  # For init_h and init_c, cuDNN expects one more dim of num_layers before or
  # after the batch dim for time-major or batch-major inputs respectively.
1452  init_h = array_ops.expand_dims(init_h, axis=seq_axis)
1453  init_c = array_ops.expand_dims(init_c, axis=seq_axis)
1454
1455  weights = array_ops.split(kernel, 4, axis=1)
1456  weights += array_ops.split(recurrent_kernel, 4, axis=1)
  # CuDNN has an extra set of biases for inputs; we disable them (setting them
  # to 0) so that mathematically it is the same as the canonical LSTM
  # implementation.
1459  full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
1460
1461  if sysconfig.get_build_info()['is_rocm_build']:
    # ROCm MIOpen's weight sequence for LSTM is different from both the
    # canonical and the CuDNN format.
    # MIOpen: [i, f, o, c]; CuDNN/canonical: [i, f, c, o]
1465    # i is input gate weights.
1466    # f is forget gate weights.
1467    # o is output gate weights.
1468    # c is cell gate weights.
1469    weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1470    # full_bias is a tensor of shape (8*n,)
1471    full_bias = array_ops.split(full_bias, 8, axis=0)
1472    full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1473
1474  params = _canonical_to_params(
1475      weights=weights,
1476      biases=array_ops.split(full_bias, 8),
1477      shape=constant_op.constant([-1]),
1478      transpose_weights=True)
1479
1480  if mask is not None:
1481    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
1482
1483  if sequence_lengths is not None:
1484    if go_backwards:
1485      # Three reversals are required. E.g.,
1486      # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
1487      # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1488      # output_from_cudnn = [6, 5, 4, 0, 0]
1489      # expected_output = [0, 0, 6, 5 ,4]
1490      inputs = array_ops.reverse_sequence_v2(
1491          inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
1492    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
1493        input=inputs,
1494        input_h=init_h,
1495        input_c=init_c,
1496        params=params,
1497        is_training=True,
1498        rnn_mode='lstm',
1499        sequence_lengths=sequence_lengths,
1500        time_major=time_major)
1501    if go_backwards:
1502      outputs = array_ops.reverse_sequence_v2(
1503          outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
1504      outputs = array_ops.reverse(outputs, axis=[seq_axis])
1505  else:
1506    # # Fill the array with shape [batch] with value of max timesteps.
1507    # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
1508    #                                  array_ops.shape(inputs)[0])
1509    if go_backwards:
      # Reverse axis 0 since the input is already converted to time major.
1511      inputs = array_ops.reverse(inputs, axis=[0])
1512    outputs, h, c, _ = gen_cudnn_rnn_ops.CudnnRNN(
1513        input=inputs, input_h=init_h, input_c=init_c, params=params,
1514        is_training=True, rnn_mode='lstm')
1515
1516  last_output = outputs[-1]
1517  if not time_major and mask is None:
1518    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
1519  h = array_ops.squeeze(h, axis=seq_axis)
1520  c = array_ops.squeeze(c, axis=seq_axis)
1521
  # In the case of variable-length input, the cuDNN kernel will fill zeros for
  # the output, whereas the default Keras behavior is to bring over the output
  # from t-1, so that in the return_sequences=False case, the user can quickly
  # get the final effective output instead of just 0s at the last timestep.
  # In order to mimic the default Keras behavior, we copy the final h state as
  # the last_output, since it is numerically the same as the output.
1528  if mask is not None:
1529    last_output = h
1530  return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)
1531
1532
1533def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
1534                                recurrent_kernel, bias, mask, time_major,
1535                                go_backwards, sequence_lengths,
1536                                zero_output_for_mask):
1537  """Call the LSTM with optimized backend kernel selection.
1538
  Under the hood, this function will create two TF functions: one with the most
  generic kernel, which can run on all device conditions, and a second one with
  the CuDNN-specific kernel, which can only run on GPU.

  The first function will be called with normal_lstm_params, while the second
  function is not called, but only registered in the graph. Grappler will then
  do the proper graph rewrite and swap in the optimized TF function based on
  the device placement.
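
  As a sketch, the pairing is expressed through the function attributes
  attached by `_generate_defun_backend` below; the values shown here are
  illustrative only:

  ```
    {
        'api_implements': 'lstm_<uuid>',   # shared by the CPU and GPU defuns
        'api_preferred_device': 'CPU',     # 'GPU' for the CuDNN variant
        'time_major': time_major,
        'go_backwards': go_backwards,
    }
  ```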
1547
1548  Args:
1549    inputs: Input tensor of LSTM layer.
1550    init_h: Initial state tensor for the cell output.
1551    init_c: Initial state tensor for the cell hidden state.
1552    kernel: Weights for cell kernel.
1553    recurrent_kernel: Weights for cell recurrent kernel.
1554    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1555      is used in this case.
    mask: Boolean tensor for masking out the steps within the sequence.
1557      An individual `True` entry indicates that the corresponding timestep
1558      should be utilized, while a `False` entry indicates that the corresponding
1559      timestep should be ignored.
1560    time_major: Boolean, whether the inputs are in the format of
1561      [time, batch, feature] or [batch, time, feature].
1562    go_backwards: Boolean (default False). If True, process the input sequence
1563      backwards and return the reversed sequence.
1564    sequence_lengths: The lengths of all sequences coming from a variable length
1565      input, such as ragged tensors. If the input has a fixed timestep size,
1566      this should be None.
1567    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
1568
1569  Returns:
1570    List of output tensors, same as standard_lstm.
1571  """
1572  params = {
1573      'inputs': inputs,
1574      'init_h': init_h,
1575      'init_c': init_c,
1576      'kernel': kernel,
1577      'recurrent_kernel': recurrent_kernel,
1578      'bias': bias,
1579      'mask': mask,
1580      'time_major': time_major,
1581      'go_backwards': go_backwards,
1582      'sequence_lengths': sequence_lengths,
1583      'zero_output_for_mask': zero_output_for_mask,
1584  }
1585
1586  def gpu_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
1587                             bias, mask, time_major, go_backwards,
1588                             sequence_lengths, zero_output_for_mask):
1589    """Use CuDNN kernel when mask is none or strictly right padded."""
1590    if mask is None:
1591      return gpu_lstm(
1592          inputs=inputs,
1593          init_h=init_h,
1594          init_c=init_c,
1595          kernel=kernel,
1596          recurrent_kernel=recurrent_kernel,
1597          bias=bias,
1598          mask=mask,
1599          time_major=time_major,
1600          go_backwards=go_backwards,
1601          sequence_lengths=sequence_lengths)
1602
1603    def cudnn_lstm_fn():
1604      return gpu_lstm(
1605          inputs=inputs,
1606          init_h=init_h,
1607          init_c=init_c,
1608          kernel=kernel,
1609          recurrent_kernel=recurrent_kernel,
1610          bias=bias,
1611          mask=mask,
1612          time_major=time_major,
1613          go_backwards=go_backwards,
1614          sequence_lengths=sequence_lengths)
1615
    def standard_lstm_fn():
1617      return standard_lstm(
1618          inputs=inputs,
1619          init_h=init_h,
1620          init_c=init_c,
1621          kernel=kernel,
1622          recurrent_kernel=recurrent_kernel,
1623          bias=bias,
1624          mask=mask,
1625          time_major=time_major,
1626          go_backwards=go_backwards,
1627          sequence_lengths=sequence_lengths,
1628          zero_output_for_mask=zero_output_for_mask)
1629
1630    return control_flow_ops.cond(
1631        is_cudnn_supported_inputs(mask, time_major),
1632        true_fn=cudnn_lstm_fn,
        false_fn=standard_lstm_fn)
1634
1635  if _use_new_code():
1636    # Chooses the implementation dynamically based on the running device.
1637    (last_output, outputs, new_h, new_c,
1638     runtime) = control_flow_ops.execute_fn_for_device(
1639         {
1640             _CPU_DEVICE_NAME: lambda: standard_lstm(**params),
1641             _GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(**params)
1642         }, lambda: standard_lstm(**params))
1643  else:
    # Each time a `tf.function` is called, we will give it a unique
    # identifiable API name, so that Grappler won't get confused when it
    # sees multiple LSTM layers added into the same graph, and it will be able
    # to pair up the different implementations across them.
1648    api_name = 'lstm_' + str(uuid.uuid4())
1649    supportive_attribute = {
1650        'time_major': time_major,
1651        'go_backwards': go_backwards,
1652    }
1653    defun_standard_lstm = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
1654                                                  standard_lstm,
1655                                                  supportive_attribute)
1656    defun_gpu_lstm = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
1657                                             gpu_lstm_with_fallback,
1658                                             supportive_attribute)
1659
    # Call the normal LSTM impl and register the CuDNN impl function. Grappler
    # will kick in during session execution to optimize the graph.
1662    last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(**params)
1663    _function_register(defun_gpu_lstm, **params)
1664
1665  return last_output, outputs, new_h, new_c, runtime
1666
1667
1668def is_sequence_right_padded(mask):
1669  """Check the mask tensor and see if it right padded.
1670
1671  For CuDNN kernel, it uses the sequence length param to skip the tailing
1672  timestep. If the data is left padded, or not a strict right padding (has
1673  masked value in the middle of the sequence), then CuDNN kernel won't be work
1674  properly in those cases.
1675
1676  Left padded data: [[False, False, True, True, True]].
1677  Right padded data: [[True, True, True, False, False]].
1678  Mixture of mask/unmasked data: [[True, False, True, False, False]].
1679
  Note that for the mixed data example above, the actual data the RNN should
  see is the 2 True entries (index 0 and 2); the False at index 1 should be
  ignored and not pollute the internal states.
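
  A worked example of the check implemented below, written out by hand:

    mask              = [[True, False, True]]
    count_of_true     = [2]
    right_padded_mask = [[True, True, False]]

  Since the reconstructed right-padded mask differs from the original mask,
  the result is False.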
1683
1684  Args:
1685    mask: the Boolean tensor with shape [batch, timestep]
1686
1687  Returns:
1688    boolean scalar tensor, whether the mask is strictly right padded.
1689  """
1690  max_seq_length = array_ops.shape(mask)[1]
1691  count_of_true = math_ops.reduce_sum(math_ops.cast(mask, dtypes.int32), axis=1)
1692  right_padded_mask = array_ops.sequence_mask(
1693      count_of_true, maxlen=max_seq_length)
1694  return math_ops.reduce_all(math_ops.equal(mask, right_padded_mask))
1695
1696
1697def has_fully_masked_sequence(mask):
1698  # See https://github.com/tensorflow/tensorflow/issues/33148 for more details.
  # The Cudnn kernel will error out if the input sequence contains any fully
  # masked data. We work around this issue by rerouting the computation to the
  # standard kernel until the issue on the cudnn side has been fixed.
  # A fully masked sequence contains all Falses. To make this easy to check,
  # we invert the boolean and check if any of the sequences is all True.
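  # For example, mask = [[True, False], [False, False]]: the second sequence is
  # entirely False (fully masked), so this function returns True.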
1704  return math_ops.reduce_any(
1705      math_ops.reduce_all(
1706          math_ops.logical_not(mask),
1707          axis=1))
1708
1709
1710def is_cudnn_supported_inputs(mask, time_major):
1711  if time_major:
1712    mask = array_ops.transpose(mask)
1713
1714  return math_ops.logical_and(
1715      is_sequence_right_padded(mask),
1716      math_ops.logical_not(has_fully_masked_sequence(mask)))
1717
1718
1719def calculate_sequence_by_mask(mask, time_major):
1720  """Calculate the sequence length tensor (1-D) based on the masking tensor.
1721
1722  The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
1723  any timestep that should be masked, the corresponding field will be False.
1724  Consider the following example:
1725    a = [[True, True, False, False],
1726         [True, True, True, False]]
  It is a (2, 4) tensor, and the corresponding sequence length result should be
  a 1D tensor with values [2, 3]. Note that the masking tensor must be right
  padded, which can be checked by, e.g., `is_sequence_right_padded()`.
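
  A runnable version of the example above, using this module's import aliases:

  ```
    a = constant_op.constant([[True, True, False, False],
                              [True, True, True, False]])
    calculate_sequence_by_mask(a, time_major=False)  # -> [2, 3]
  ```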
1730
1731  Args:
1732    mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] if
1733      time_major=True.
1734    time_major: Boolean, which indicates whether the mask is time major or batch
1735      major.
1736  Returns:
1737    sequence_length: 1D int32 tensor.
1738  """
1739  timestep_index = 0 if time_major else 1
1740  return math_ops.reduce_sum(math_ops.cast(mask, dtypes.int32),
1741                             axis=timestep_index)
1742
1743
1744def _generate_defun_backend(unique_api_name, preferred_device, func,
1745                            supportive_attributes):
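  """Build a defun for `func` that is tagged for backend implementation swap.

  The returned function carries the API-name and preferred-device attributes
  (plus the layer-specific `supportive_attributes`), so that Grappler can pair
  up the CPU and GPU implementations that share the same API name and swap in
  the one matching the actual device placement.
  """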
1746  function_attributes = {
1747      _FUNCTION_API_NAME_ATTRIBUTE: unique_api_name,
1748      _FUNCTION_DEVICE_ATTRIBUTE: preferred_device,
1749  }
1750  function_attributes.update(supportive_attributes)
1751  return function.defun_with_attributes(func=func,
1752                                        attributes=function_attributes,
1753                                        autograph=False)
1754
1755
1756def _get_context_device_type():
1757  """Parse the current context and return the device type, eg CPU/GPU."""
1758  current_device = get_device_name()
1759  if current_device is None:
1760    return None
1761  return device.DeviceSpec.from_string(current_device).device_type
1762
1763
1764def _runtime(runtime_name):
1765  with ops.device('/cpu:0'):
1766    return constant_op.constant(
1767        runtime_name, dtype=dtypes.float32, name='runtime')
1768
1769
1770def _read_variable_value(v):
1771  """Read the value of a variable if it is variable."""
1772  if isinstance(v, variables.Variable):
1773    return v.read_value()
1774  return v
1775
1776
1777def _function_register(func, *args, **kwargs):
1778  """Register a specialization of a `Function` into the graph.
1779
  This won't actually call the function with the inputs; it only puts the
  function definition into the graph. Registering the function with different
  input params will result in multiple versions of the function being
  registered in the graph.
1783
1784  Args:
    func: the `Function` instance generated by a @defun.
1786    *args: input arguments for the Python function.
1787    **kwargs: input keyword arguments for the Python function.
1788
1789  Returns:
1790    a `ConcreteFunction` object specialized to inputs and execution context.
1791
1792  Raises:
1793    ValueError: When the input function is not a defun wrapped python function.
1794  """
1795  if not isinstance(func, function.Function):
1796    raise ValueError('Only defun function is allowed to be registered. '
1797                     'Got type: %s' % type(func))
1798  concrete_func = func.get_concrete_function(*args, **kwargs)
1799  concrete_func.add_to_graph()
1800  concrete_func.add_gradient_functions_to_graph()
1801  return concrete_func
1802