1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Recurrent layers for TF 2.0.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import uuid
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import device
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.keras import backend as K
30from tensorflow.python.keras.engine.input_spec import InputSpec
31from tensorflow.python.keras.layers import recurrent
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import gen_cudnn_rnn_ops
34from tensorflow.python.ops import state_ops
35from tensorflow.python.util.tf_export import keras_export
36
37
# The following string constants are used by the Defun approach for the unified
# backend of LSTM and GRU.
# Attribute keys. They are not referenced in the visible part of this file;
# presumably they are consumed by `_generate_defun_backend` -- TODO confirm.
_DEFUN_API_NAME_ATTRIBUTE = 'api_implements'
_DEFUN_DEVICE_ATTRIBUTE = 'api_preferred_device'
# Device type names handed to `_generate_defun_backend` when the two backend
# implementations are wrapped. `_GPU_DEVICE_NAME` is also compared against the
# result of `_get_context_device_type()` when executing eagerly.
_CPU_DEVICE_NAME = 'CPU'
_GPU_DEVICE_NAME = 'GPU'
44
45
@keras_export('keras.layers.GRU', v1=[])
class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
  """Gated Recurrent Unit - Cho et al. 2014.

  Based on available runtime hardware and constraints, this layer
  will choose different implementations (cuDNN-based or pure-TensorFlow)
  to maximize the performance. If a GPU is available and all
  the arguments to the layer meet the requirement of the CuDNN kernel
  (see below for details), the layer will use a fast cuDNN implementation.

  The requirements to use the cuDNN implementation are:

  1. `activation` == 'tanh'
  2. `recurrent_activation` == 'sigmoid'
  3. `recurrent_dropout` == 0
  4. `unroll` is False
  5. `use_bias` is True
  6. `reset_after` is True
  7. No use of masking.

  There are two variants of the GRU implementation. The first one is based on
  [v3](https://arxiv.org/abs/1406.1078v3) and has reset gate applied to hidden
  state before matrix multiplication (`reset_after=False`). The other one is
  based on [original](https://arxiv.org/abs/1406.1078v1) and has the order
  reversed (`reset_after=True`, which is the default of this layer).

  The second variant is compatible with CuDNNGRU (GPU-only) and allows
  inference on CPU. Thus it has separate biases for `kernel` and
  `recurrent_kernel`. To use this variant, set `reset_after=True` and
  `recurrent_activation='sigmoid'`.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: sigmoid (`sigmoid`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
       weights matrix,
       used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2.
      Mode 1 will structure its operations as a larger number of
      smaller dot products and additions, whereas mode 2 will
      batch them into fewer, larger operations. These modes will
      have different performance profiles on different hardware and
      for different applications.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False).
      If True, the network will be unrolled,
      else a symbolic loop will be used.
      Unrolling can speed-up a RNN,
      although it tends to be more memory-intensive.
      Unrolling is only suitable for short sequences.
    time_major: Boolean (default False). Whether the inputs are in the format
      of [time, batch, feature] (True) or [batch, time, feature] (False).
    reset_after: GRU convention (whether to apply reset gate after or
      before matrix multiplication). False = "before",
      True = "after" (default and CuDNN compatible).

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
  """

  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               time_major=False,
               reset_after=True,
               **kwargs):
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode. It is popped here so it
    # is not forwarded to the parent constructor with the remaining kwargs.
    self._return_runtime = kwargs.pop('return_runtime', False)

    super(GRU, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        unroll=unroll,
        time_major=time_major,
        reset_after=reset_after,
        **kwargs)
    # CuDNN uses following setting by default and not configurable.
    # Note that the activations are compared against the string names, so only
    # the literal 'tanh' / 'sigmoid' arguments satisfy this check.
    self.could_use_cudnn = (
        activation == 'tanh' and recurrent_activation == 'sigmoid' and
        recurrent_dropout == 0 and not unroll and use_bias and
        reset_after)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    """Runs the GRU over `inputs` on the generic or the cuDNN-capable path.

    Args:
      inputs: layer input, forwarded to `self._process_inputs`.
      mask: optional mask. If a list is given only its first element is used.
        Any non-None mask forces the generic `K.rnn` path.
      training: forwarded to the cell on the generic path, or used to build
        the input dropout mask on the defun path.
      initial_state: optional initial state, forwarded to
        `self._process_inputs`.

    Returns:
      `[output] + states` if `return_state` is set; otherwise
      `(output, runtime)` if the testing-only `return_runtime` flag was passed
      to the constructor; otherwise `output`. `output` is the full sequence
      when `return_sequences` is set, else the last output.
    """
    # GRU does not support constants. Ignore it during process.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

    if isinstance(mask, list):
      mask = mask[0]

    input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]

    if mask is not None or not self.could_use_cudnn:
      # CuDNN does not support masking, fall back to use the normal GRU. The
      # same fallback is taken when the layer configuration does not meet the
      # cuDNN requirements checked in the constructor.
      # NOTE(review): only the layer-level dropout mask is reset (in
      # `_defun_gru_call`); nothing on this path resets masks the cell may
      # cache between calls -- confirm the cell handles that itself.
      kwargs = {'training': training}

      def step(cell_inputs, cell_states):
        return self.cell.call(cell_inputs, cell_states, **kwargs)

      last_output, outputs, states = K.rnn(
          step,
          inputs,
          initial_state,
          constants=None,
          go_backwards=self.go_backwards,
          mask=mask,
          unroll=self.unroll,
          input_length=timesteps,
          time_major=self.time_major,
          zero_output_for_mask=self.zero_output_for_mask)
      # This is a dummy tensor for testing purpose.
      runtime = _runtime('unknown')
    else:
      last_output, outputs, runtime, states = self._defun_gru_call(
          inputs, initial_state, training)

    if self.stateful:
      # Carry the final state over into the layer's state variable.
      updates = [state_ops.assign(self.states[0], states[0])]
      self.add_update(updates, inputs)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    # `return_state` takes precedence: the runtime tensor is not returned
    # together with the states.
    if self.return_state:
      return [output] + list(states)
    elif self._return_runtime:
      return output, runtime
    else:
      return output

  def _defun_gru_call(self, inputs, initial_state, training):
    """Runs the GRU through the swappable standard/cuDNN implementations.

    When executing eagerly the implementation is chosen here from the current
    device context. Otherwise both implementations are wrapped under one
    unique API name, the standard one is called and the cuDNN one is
    registered for grappler to swap in.

    Args:
      inputs: input tensor; reversed along the time axis when `go_backwards`.
      initial_state: list whose first element is the initial hidden state.
      training: forwarded to `get_dropout_mask_for_cell`.

    Returns:
      Tuple `(last_output, outputs, runtime, states)` where `states` is a
      one-element list holding the new hidden state.
    """
    # Use the new defun approach for backend implementation swap.
    # Note that different implementations need to have same function
    # signature, eg, the tensor parameters need to have same shape and dtypes.
    if self.go_backwards:
      # Reverse time axis.
      inputs = K.reverse(inputs, 0 if self.time_major else 1)

    self.reset_dropout_mask()
    # Three masks are requested, but only the first one is applied, and it is
    # applied to the whole input before the backend implementation runs.
    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
    if dropout_mask is not None:
      inputs *= dropout_mask[0]
    if context.executing_eagerly():
      device_type = _get_context_device_type()
      if device_type == _GPU_DEVICE_NAME or (
          device_type is None and context.num_gpus() > 0):
        # Under eager context, check the device placement and prefer the
        # GPU implementation when GPU is available.
        last_output, outputs, new_h, runtime = cudnn_gru(
            inputs=inputs,
            init_h=initial_state[0],
            kernel=self.cell.kernel,
            recurrent_kernel=self.cell.recurrent_kernel,
            bias=self.cell.bias,
            time_major=self.time_major)
      else:
        last_output, outputs, new_h, runtime = standard_gru(
            inputs=inputs,
            init_h=initial_state[0],
            kernel=self.cell.kernel,
            recurrent_kernel=self.cell.recurrent_kernel,
            bias=self.cell.bias,
            activation=self.activation,
            recurrent_activation=self.recurrent_activation,
            time_major=self.time_major)
    else:
      # A fresh API name per call ties this particular pair of standard and
      # cuDNN functions together.
      api_name = 'gru_' + str(uuid.uuid4())
      defun_standard_gru = _generate_defun_backend(
          api_name, _CPU_DEVICE_NAME, standard_gru)
      defun_cudnn_gru = _generate_defun_backend(
          api_name, _GPU_DEVICE_NAME, cudnn_gru)
      # Call the normal GRU impl and register the CuDNN impl function. The
      # grappler will kick in during session execution to optimize the graph.
      last_output, outputs, new_h, runtime = defun_standard_gru(
          inputs=inputs,
          init_h=initial_state[0],
          kernel=self.cell.kernel,
          recurrent_kernel=self.cell.recurrent_kernel,
          bias=self.cell.bias,
          activation=self.activation,
          recurrent_activation=self.recurrent_activation,
          time_major=self.time_major)

      function.register(defun_cudnn_gru, inputs, initial_state[0],
                        self.cell.kernel, self.cell.recurrent_kernel,
                        self.cell.bias, self.time_major)
    states = [new_h]
    return last_output, outputs, runtime, states
319
320
def standard_gru(inputs, init_h, kernel, recurrent_kernel, bias, activation,
                 recurrent_activation, time_major):
  """Device-agnostic GRU built on the Keras backend RNN loop.

  All layer weights are taken as function parameters so that this function
  has the same tensor inputs as the CuDNN counterpart. The per-step logic is
  deliberately minimal: there is no dropout and no masking, since the CuDNN
  implementation supports neither.

  Arguments:
    inputs: input tensor of GRU layer.
    init_h: initial state tensor for the cell output.
    kernel: weights for cell kernel.
    recurrent_kernel: weights for cell recurrent kernel.
    bias: combined bias tensor. It is unstacked into the input bias and the
      recurrent bias, in that order.
    activation: Activation function applied to the candidate state.
    recurrent_activation: Activation function applied to the update and reset
      gates.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].

  Returns:
    last_output: output tensor for the last timestep, which has shape
      [batch, units].
    outputs: output tensor for all timesteps as produced by `K.rnn` (shape
      [batch, time, units] for batch-major inputs).
    state_0: the cell output, which has same shape as init_h.
    runtime: constant string tensor which indicates the real runtime hardware.
      This value is for testing purpose and should not be relied on by users.
  """
  time_axis = 0 if time_major else 1
  timesteps = K.int_shape(inputs)[time_axis]

  input_bias, recurrent_bias = array_ops.unstack(bias)

  def step(cell_inputs, cell_states):
    """Computes one GRU step for the Keras RNN backend."""
    prev_h = cell_states[0]

    # Project the inputs for all three gates with a single matmul.
    projected_x = K.dot(cell_inputs, kernel)
    projected_x = K.bias_add(projected_x, input_bias)
    x_z, x_r, x_h = array_ops.split(projected_x, 3, axis=1)

    # Same for the previous hidden state.
    projected_h = K.dot(prev_h, recurrent_kernel)
    projected_h = K.bias_add(projected_h, recurrent_bias)
    h_z, h_r, h_h = array_ops.split(projected_h, 3, axis=1)

    update_gate = recurrent_activation(x_z + h_z)
    reset_gate = recurrent_activation(x_r + h_r)
    # The reset gate scales the already-projected hidden state.
    candidate = activation(x_h + reset_gate * h_h)

    # Blend the previous state and the candidate using the update gate.
    new_h = update_gate * prev_h + (1 - update_gate) * candidate
    return new_h, [new_h]

  rnn_result = K.rnn(
      step,
      inputs, [init_h],
      constants=None,
      unroll=False,
      time_major=time_major,
      input_length=timesteps)
  last_output, outputs, final_states = rnn_result
  return last_output, outputs, final_states[0], _runtime('cpu')
390
391
def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, time_major):
  """GRU with CuDNN implementation which is only available for GPU.

  Arguments:
    inputs: input tensor of GRU layer. It is transposed to time-major before
      being fed to the CuDNN op when `time_major` is False.
    init_h: initial state tensor; a leading axis is added before it is passed
      to the CuDNN op.
    kernel: weights for cell kernel, split into three gate blocks on axis 1.
    recurrent_kernel: weights for cell recurrent kernel, split the same way.
    bias: bias tensor, flattened and split into six per-gate pieces.
    time_major: boolean, whether `inputs` is already time-major.

  Returns:
    last_output: the last entry of the CuDNN op output along its first axis.
    outputs: the CuDNN op output, transposed back when `time_major` is False.
    h: the final hidden state with the leading axis removed.
    runtime: the `_runtime('cudnn')` tensor, for testing purpose.
  """
  batch_major = not time_major
  if batch_major:
    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
  init_h = array_ops.expand_dims(init_h, axis=0)

  # Canonical gate order is [z, r, h]:
  #   z is the update gate, r is the reset gate, h is the candidate state.
  kernel_z, kernel_r, kernel_h = array_ops.split(kernel, 3, axis=1)
  rec_z, rec_r, rec_h = array_ops.split(recurrent_kernel, 3, axis=1)
  # Note that the bias was initialized as shape (2, 3 * units), flat it into
  # (6 * units) and cut it into input biases followed by recurrent biases.
  (in_bias_z, in_bias_r, in_bias_h,
   rec_bias_z, rec_bias_r, rec_bias_h) = array_ops.split(K.flatten(bias), 6)

  # CuDNN expects the gates ordered as [r, z, h], so z and r trade places in
  # the kernel, the recurrent kernel and both bias groups.
  cudnn_weights = [kernel_r, kernel_z, kernel_h, rec_r, rec_z, rec_h]
  cudnn_biases = [in_bias_r, in_bias_z, in_bias_h,
                  rec_bias_r, rec_bias_z, rec_bias_h]

  params = _canonical_to_params(
      weights=cudnn_weights,
      biases=cudnn_biases,
      shape=constant_op.constant([-1]),
      transpose_weights=True)

  # `input_c=0`: no cell state is supplied for GRU (presumably ignored by the
  # op for rnn_mode='gru' -- TODO confirm). The op is always built with
  # is_training=True.
  outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
      inputs,
      input_h=init_h,
      input_c=0,
      params=params,
      is_training=True,
      rnn_mode='gru')
  # Taken before the transpose below, i.e. along the op's first axis.
  last_output = outputs[-1]
  if batch_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
  h = h[0]
  return last_output, outputs, h, _runtime('cudnn')
432
433
@keras_export('keras.layers.LSTM', v1=[])
class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
  """Long Short-Term Memory layer - Hochreiter 1997.

  Based on available runtime hardware and constraints, this layer
  will choose different implementations (cuDNN-based or pure-TensorFlow)
  to maximize the performance. If a GPU is available and all
  the arguments to the layer meet the requirement of the CuDNN kernel
  (see below for details), the layer will use a fast cuDNN implementation.

  The requirements to use the cuDNN implementation are:

  1. `activation` == 'tanh'
  2. `recurrent_activation` == 'sigmoid'
  3. `recurrent_dropout` == 0
  4. `unroll` is False
  5. `use_bias` is True
  6. No use of masking.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
      is applied (ie. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use for the recurrent step.
      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
      applied (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at
      initialization. Setting it to true will also force
      `bias_initializer="zeros"`. This is recommended in
      [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
      weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
      transformation of the inputs.
    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2. Mode 1 will structure
      its operations as a larger number of smaller dot products and additions,
      whereas mode 2 will batch them into fewer, larger operations. These modes
      will have different performance profiles on different hardware and for
      different applications.
    return_sequences: Boolean. Whether to return the last output in the output
      sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
    time_major: Boolean (default False). Whether the inputs are in the format
      of [time, batch, feature] (True) or [batch, time, feature] (False).
    unroll: Boolean (default False). If True, the network will be unrolled, else
      a symbolic loop will be used. Unrolling can speed-up a RNN, although it
      tends to be more memory-intensive. Unrolling is only suitable for short
      sequences.

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
  """

  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               time_major=False,
               unroll=False,
               **kwargs):
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode. It is popped here so it
    # is not forwarded to the parent constructor with the remaining kwargs.
    self.return_runtime = kwargs.pop('return_runtime', False)

    super(LSTM, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        unit_forget_bias=unit_forget_bias,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        time_major=time_major,
        unroll=unroll,
        **kwargs)

    # Two states (cell output and carry), each of width `units`.
    self.state_spec = [
        InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
    ]
    # CuDNN uses these settings and they are not configurable. The activations
    # are compared against the string names, so only the literal 'tanh' /
    # 'sigmoid' arguments satisfy this check.
    self.could_use_cudnn = (
        activation == 'tanh' and recurrent_activation == 'sigmoid' and
        recurrent_dropout == 0 and not unroll and use_bias)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    """Runs the LSTM over `inputs` on the generic or the cuDNN-capable path.

    Args:
      inputs: layer input, forwarded to `self._process_inputs`.
      mask: optional mask. If a list is given only its first element is used.
        Any non-None mask forces the generic `K.rnn` path.
      training: forwarded to the cell on the generic path, or used to build
        the input dropout mask on the defun path.
      initial_state: optional initial state, forwarded to
        `self._process_inputs`.

    Returns:
      `[output] + states` if `return_state` is set; otherwise
      `(output, runtime)` if the testing-only `return_runtime` flag was passed
      to the constructor; otherwise `output`. `output` is the full sequence
      when `return_sequences` is set, else the last output.
    """
    # LSTM does not support constants. Ignore it during process.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

    if isinstance(mask, list):
      mask = mask[0]

    input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]

    if mask is not None or not self.could_use_cudnn:
      # CuDNN does not support masking, fall back to use the normal LSTM. The
      # same fallback is taken when the layer configuration does not meet the
      # cuDNN requirements checked in the constructor.
      # NOTE(review): only the layer-level dropout mask is reset (in the
      # branch below); nothing on this path resets masks the cell may cache
      # between calls -- confirm the cell handles that itself.
      kwargs = {'training': training}

      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

      last_output, outputs, states = K.rnn(
          step,
          inputs,
          initial_state,
          constants=None,
          go_backwards=self.go_backwards,
          mask=mask,
          unroll=self.unroll,
          input_length=timesteps,
          time_major=self.time_major,
          zero_output_for_mask=self.zero_output_for_mask)
      # This is a dummy tensor for testing purpose.
      runtime = _runtime('unknown')
    else:
      # Use the new defun approach for backend implementation swap.
      # Note that different implementations need to have same function
      # signature, eg, the tensor parameters need to have same shape and dtypes.
      # Since the CuDNN has an extra set of bias, those bias will be passed to
      # both normal and CuDNN implementations.
      if self.go_backwards:
        # Reverse time axis.
        inputs = K.reverse(inputs, 0 if self.time_major else 1)

      self.reset_dropout_mask()
      # Four masks are requested, but only the first one is applied, and it is
      # applied to the whole input before the backend implementation runs.
      dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
      if dropout_mask is not None:
        inputs *= dropout_mask[0]

      if context.executing_eagerly():
        device_type = _get_context_device_type()
        if device_type == _GPU_DEVICE_NAME or (
            device_type is None and context.num_gpus() > 0):
          # Under eager context, check the device placement and prefer the
          # GPU implementation when GPU is available.
          last_output, outputs, new_h, new_c, runtime = cudnn_lstm(
              inputs, initial_state[0], initial_state[1], self.cell.kernel,
              self.cell.recurrent_kernel, self.cell.bias, self.time_major)
        else:
          last_output, outputs, new_h, new_c, runtime = standard_lstm(
              inputs, initial_state[0], initial_state[1], self.cell.kernel,
              self.cell.recurrent_kernel, self.cell.bias, self.activation,
              self.recurrent_activation, self.time_major)
      else:
        # Each time a `tf.function` is called, we will give it a unique
        # identifiable API name, so that Grappler won't get confused when it
        # sees multiple LSTM layers added into same graph, and it will be able
        # to pair up the different implementations across them.
        api_name = 'lstm_' + str(uuid.uuid4())
        defun_standard_lstm = _generate_defun_backend(
            api_name, _CPU_DEVICE_NAME, standard_lstm)
        defun_cudnn_lstm = _generate_defun_backend(
            api_name, _GPU_DEVICE_NAME, cudnn_lstm)

        # Call the normal LSTM impl and register the CuDNN impl function. The
        # grappler will kick in during session execution to optimize the graph.
        last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
            inputs, initial_state[0], initial_state[1], self.cell.kernel,
            self.cell.recurrent_kernel, self.cell.bias, self.activation,
            self.recurrent_activation, self.time_major)

        function.register(defun_cudnn_lstm, inputs, initial_state[0],
                          initial_state[1], self.cell.kernel,
                          self.cell.recurrent_kernel, self.cell.bias,
                          self.time_major)
      states = [new_h, new_c]

    if self.stateful:
      # Carry each final state over into the matching layer state variable.
      updates = [
          state_ops.assign(state_var, new_state)
          for state_var, new_state in zip(self.states, states)
      ]
      self.add_update(updates, inputs)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    # `return_state` takes precedence: the runtime tensor is not returned
    # together with the states.
    if self.return_state:
      return [output] + list(states)
    elif self.return_runtime:
      return output, runtime
    else:
      return output
681
682
def _canonical_to_params(weights, biases, shape, transpose_weights=False):
  """Packs per-gate weights and biases into a single CuDNN parameter tensor.

  Note that Keras weights for kernels are different from the CuDNN format. Eg.:

  ```
    Keras                 CuDNN
    [[0, 1, 2],  <--->  [[0, 2, 4],
     [3, 4, 5]]          [1, 3, 5]]
  ```

  If the input weights need to be in a unified format, then set
  `transpose_weights=True` to convert the weights.

  Args:
    weights: list of weights for the individual kernels and recurrent kernels.
    biases: list of biases for individual gate.
    shape: the shape every piece is reshaped to before being fed to CuDNN.
    transpose_weights: boolean, whether to transpose each weight first. Biases
      are never transposed.

  Returns:
    The reshaped weights followed by the reshaped biases, concatenated along
    axis 0, that can be fed to CuDNN ops as param.
  """
  pieces = []
  for weight in weights:
    if transpose_weights:
      weight = array_ops.transpose(weight)
    pieces.append(array_ops.reshape(weight, shape))
  # Biases come after all the weights in the packed tensor.
  for gate_bias in biases:
    pieces.append(array_ops.reshape(gate_bias, shape))
  return array_ops.concat(pieces, axis=0)
712
713
def standard_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias,
                  activation, recurrent_activation, time_major):
  """LSTM forward pass built from generic Keras backend ops.

  This implementation can be run on all types of hardware.

  The layer weights are lifted out and passed as function parameters, so this
  function takes the same tensor arguments as the CuDNN counterpart. The RNN
  step logic has been simplified, eg dropout and mask are removed since the
  CuDNN implementation does not support them.

  At every step the whole `bias` tensor is added to the sum of the input and
  recurrent projections. The result is split into four equal parts along
  axis 1 which feed, in order, the input gate, the forget gate, the candidate
  value and the output gate.

  Args:
    inputs: input tensor of LSTM layer.
    init_h: initial state tensor for the cell output.
    init_c: initial state tensor for the cell hidden state.
    kernel: weights for cell kernel.
    recurrent_kernel: weights for cell recurrent kernel.
    bias: bias added to the combined kernel and recurrent kernel projections.
    activation: Activation function applied to the candidate value and to the
      updated carry state.
    recurrent_activation: Activation function applied to the input, forget
      and output gates.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].

  Returns:
    last_output: output tensor for the last timestep, which has shape
      [batch, units].
    outputs: output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: the cell output, which has same shape as init_h.
    state_1: the cell hidden state, which has same shape as init_c.
    runtime: constant string tensor which indicates the real runtime hardware
      ('cpu' for this implementation). This value is intended for testing
      purposes only.
  """
  # The time dimension is axis 0 for time-major inputs and axis 1 otherwise.
  time_axis = 0 if time_major else 1
  timesteps = K.int_shape(inputs)[time_axis]

  def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    prev_output = cell_states[0]  # previous memory state
    prev_carry = cell_states[1]  # previous carry state

    projected = K.dot(cell_inputs, kernel) + K.dot(prev_output,
                                                   recurrent_kernel)
    projected = K.bias_add(projected, bias)

    # One slice per gate: input, forget, candidate, output.
    z_input, z_forget, z_candidate, z_output = array_ops.split(
        projected, 4, axis=1)

    input_gate = recurrent_activation(z_input)
    forget_gate = recurrent_activation(z_forget)
    carry = forget_gate * prev_carry + input_gate * activation(z_candidate)
    output_gate = recurrent_activation(z_output)

    output = output_gate * activation(carry)
    return output, [output, carry]

  initial_states = [init_h, init_c]
  last_output, outputs, final_states = K.rnn(
      step,
      inputs,
      initial_states,
      constants=None,
      unroll=False,
      time_major=time_major,
      input_length=timesteps)
  final_h = final_states[0]
  final_c = final_states[1]
  return last_output, outputs, final_h, final_c, _runtime('cpu')
783
784
def cudnn_lstm(inputs, input_h, input_c, kernel, recurrent_kernel, bias,
               time_major):
  """LSTM with CuDNN implementation which is only available for GPU.

  The CuDNN op is always fed time-major data: when `time_major` is False the
  inputs are transposed with perm (1, 0, 2) before the op and the outputs are
  transposed back afterwards. The initial states get an extra leading axis
  before the op, and that axis is removed from the returned states.

  Args:
    inputs: input tensor of LSTM layer.
    input_h: initial state tensor for the cell output.
    input_c: initial state tensor for the cell hidden state.
    kernel: weights for cell kernel; split into 4 parts along axis 1.
    recurrent_kernel: weights for cell recurrent kernel; split into 4 parts
      along axis 1.
    bias: bias tensor. It is prefixed with an equally sized block of zeros to
      fill the extra set of input biases that CuDNN expects.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].

  Returns:
    last_output: the final entry of the time-major outputs.
    outputs: output tensor for all timesteps, in the same major-ness as the
      given `inputs`.
    h: the cell output state, with the leading axis removed.
    c: the cell hidden state, with the leading axis removed.
    runtime: constant string tensor with value 'cudnn'.
  """
  batch_major = not time_major
  if batch_major:
    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
  input_h = array_ops.expand_dims(input_h, axis=0)
  input_c = array_ops.expand_dims(input_c, axis=0)

  # Four kernel slices first, followed by four recurrent kernel slices.
  gate_weights = (array_ops.split(kernel, 4, axis=1) +
                  array_ops.split(recurrent_kernel, 4, axis=1))
  # CuDNN has an extra set of bias for inputs, we disable them (setting to 0),
  # so that mathematically it is same as the canonical LSTM implementation.
  padded_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
  gate_biases = array_ops.split(padded_bias, 8)
  flat_shape = constant_op.constant([-1])

  params = _canonical_to_params(
      weights=gate_weights,
      biases=gate_biases,
      shape=flat_shape,
      transpose_weights=True)

  outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
      inputs, input_h=input_h, input_c=input_c, params=params, is_training=True)
  # Take the last step while the outputs are still time-major.
  last_output = outputs[-1]
  if batch_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
  # Drop the leading axis that was added to the states above.
  h = h[0]
  c = c[0]

  return last_output, outputs, h, c, _runtime('cudnn')
814
815
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wraps `func` as a defun tagged with API name and preferred device.

  Args:
    unique_api_name: value stored under the 'api_implements' attribute.
    preferred_device: value stored under the 'api_preferred_device' attribute.
    func: the python function to wrap.

  Returns:
    The result of `function.defun_with_attributes` for `func` with the two
    attributes above attached.
  """
  return function.defun_with_attributes(
      func=func,
      attributes={
          _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
          _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
      })
823
824
def _get_context_device_type():
  """Returns the device type (eg CPU/GPU) of the current context, if any.

  Returns:
    The `device_type` parsed from the current context's device name, or None
    when the context has no device name.
  """
  device_name = context.context().device_name
  if device_name is not None:
    parsed_spec = device.DeviceSpec.from_string(device_name)
    return parsed_spec.device_type
  return None
831
832
def _runtime(runtime_name):
  """Creates a string constant named 'runtime' under the '/cpu:0' device.

  Args:
    runtime_name: the value of the constant, eg 'cpu' or 'cudnn'.

  Returns:
    The string constant tensor holding `runtime_name`.
  """
  with ops.device('/cpu:0'):
    runtime_tensor = constant_op.constant(
        runtime_name, dtype=dtypes.string, name='runtime')
  return runtime_tensor
837