1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Recurrent layers backed by cuDNN.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.keras import backend as K
25from tensorflow.python.keras import constraints
26from tensorflow.python.keras import initializers
27from tensorflow.python.keras import regularizers
28from tensorflow.python.keras.engine.input_spec import InputSpec
29from tensorflow.python.keras.layers import recurrent_v2
30from tensorflow.python.keras.layers.recurrent import RNN
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import gen_cudnn_rnn_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.util.tf_export import keras_export
35
36
class _CuDNNRNN(RNN):
  """Private base class for CuDNNGRU and CuDNNLSTM layers.

  This class deliberately bypasses `RNN.__init__`, so no RNN cell instance is
  created. Subclasses are expected to:
    * set `self._cell` to an object exposing `state_size` and make it available
      as `self.cell` *before* calling this initializer (the initializer reads
      `self.cell.state_size` to build `state_spec`);
    * create `kernel`, `recurrent_kernel` and `bias` weights in `build()`;
    * implement `_process_batch(inputs, initial_state)` returning a tuple
      `(output, states)` where `states` is a list of tensors.

  Arguments:
    return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
        in addition to the output.
    go_backwards: Boolean (default False).
        If True, process the input sequence backwards and return the
        reversed sequence.
    stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
    time_major: Boolean (default False). If true, the inputs and outputs will be
        in shape `(timesteps, batch, ...)`, whereas in the False case, it will
        be `(batch, timesteps, ...)`.
  """

  def __init__(self,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               time_major=False,
               **kwargs):
    # We invoke the base layer's initializer directly here because we do not
    # want to create RNN cell instance.
    super(RNN, self).__init__(**kwargs)  # pylint: disable=bad-super-call
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.time_major = time_major
    self.supports_masking = False
    self.input_spec = [InputSpec(ndim=3)]
    # `state_size` may be a single int (one state) or a sequence of sizes
    # (several states); normalize to a sequence.
    if hasattr(self.cell.state_size, '__len__'):
      state_size = self.cell.state_size
    else:
      state_size = [self.cell.state_size]
    self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
    self.constants_spec = None
    self._states = None
    self._num_constants = None
    self._num_inputs = None
    # Shape passed by subclasses to `_canonical_to_params` to flatten the
    # weight and bias blocks into a single vector.
    self._vector_shape = constant_op.constant([-1])

  def call(self, inputs, mask=None, training=None, initial_state=None):
    """Runs the cuDNN recurrence over `inputs`.

    Args:
      inputs: A 3D tensor, or a list whose first element is the input tensor
        and whose remaining elements are initial states.
      mask: Must be None (or a list whose first element is None); masking is
        not supported.
      training: Unused.
      initial_state: Optional list of initial state tensors.

    Returns:
      The output tensor, or `[output] + states` when `return_state` is True.

    Raises:
      ValueError: If a mask is given, or if the number of initial states does
        not match the number of layer states.
    """
    if isinstance(mask, list):
      mask = mask[0]
    if mask is not None:
      raise ValueError('Masking is not supported for CuDNN RNNs.')

    # Input shape: `(samples, time (padded with zeros), input_dim)`, or
    # `(time, samples, input_dim)` when `time_major` is True. `input_spec` and
    # `state_spec` are set up in `__init__`.
    if isinstance(inputs, list):
      initial_state = inputs[1:]
      inputs = inputs[0]
    elif initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')

    if self.go_backwards:
      # Reverse the time axis: axis 0 for time-major inputs, axis 1 otherwise.
      # (Always reversing axis 1 would flip the batch axis when time-major.)
      time_axis = 0 if self.time_major else 1
      inputs = K.reverse(inputs, time_axis)
    output, states = self._process_batch(inputs, initial_state)

    if self.stateful:
      # Persist the final states so the next batch starts from them.
      updates = [
          state_ops.assign(state_var, new_state)
          for state_var, new_state in zip(self.states, states)
      ]
      self.add_update(updates, inputs)

    if self.return_state:
      return [output] + states
    return output

  def get_config(self):
    """Returns the layer config, skipping `RNN.get_config` (there is no cell)."""
    config = {
        'return_sequences': self.return_sequences,
        'return_state': self.return_state,
        'go_backwards': self.go_backwards,
        'stateful': self.stateful,
        'time_major': self.time_major,
    }
    base_config = super(  # pylint: disable=bad-super-call
        RNN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Rebuilds the layer by passing the config entries as keyword args."""
    return cls(**config)

  @property
  def trainable_weights(self):
    """Kernel, recurrent kernel and bias when trainable and built, else []."""
    if self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def non_trainable_weights(self):
    """Kernel, recurrent kernel and bias when frozen and built, else []."""
    if not self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def losses(self):
    """Losses from the class after `RNN` in the MRO (bypassing `RNN`)."""
    return super(RNN, self).losses  # pylint: disable=bad-super-call

  def get_losses_for(self, inputs=None):
    """Delegates to the class after `RNN` in the MRO (bypassing `RNN`)."""
    return super(  # pylint: disable=bad-super-call
        RNN, self).get_losses_for(inputs=inputs)
159
160
@keras_export(v1=['keras.layers.CuDNNGRU'])
class CuDNNGRU(_CuDNNRNN):
  """Fast GRU implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Arguments:
      units: Positive integer, dimensionality of the output space.
      kernel_initializer: Initializer for the `kernel` weights matrix, used for
        the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output in the output
        sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state in addition to the
        output.
      go_backwards: Boolean (default False). If True, process the input sequence
        backwards and return the reversed sequence.
      stateful: Boolean (default False). If True, the last state for each sample
        at index i in a batch will be used as initial state for the sample of
        index i in the following batch.
      **kwargs: Forwarded to the base class; this includes `time_major`
        (Boolean, default False), which switches inputs and outputs to
        `(timesteps, batch, ...)` layout.
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    # No real cell is built; the base class only needs `cell.state_size`.
    # A GRU carries a single hidden state of width `units`.
    self._cell = collections.namedtuple('cell', 'state_size')(
        state_size=self.units)
    super(CuDNNGRU, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    get_initializer = initializers.get
    get_regularizer = regularizers.get
    get_constraint = constraints.get

    self.kernel_initializer = get_initializer(kernel_initializer)
    self.recurrent_initializer = get_initializer(recurrent_initializer)
    self.bias_initializer = get_initializer(bias_initializer)

    self.kernel_regularizer = get_regularizer(kernel_regularizer)
    self.recurrent_regularizer = get_regularizer(recurrent_regularizer)
    self.bias_regularizer = get_regularizer(bias_regularizer)
    self.activity_regularizer = get_regularizer(activity_regularizer)

    self.kernel_constraint = get_constraint(kernel_constraint)
    self.recurrent_constraint = get_constraint(recurrent_constraint)
    self.bias_constraint = get_constraint(bias_constraint)

  @property
  def cell(self):
    """The placeholder cell description exposing `state_size`."""
    return self._cell

  def build(self, input_shape):
    """Creates `kernel`, `recurrent_kernel` and `bias` (in that order)."""
    super(CuDNNGRU, self).build(input_shape)
    data_shape = (input_shape[0] if isinstance(input_shape, list)
                  else input_shape)
    input_dim = int(data_shape[-1])
    # Each weight matrix stacks three gate blocks of width `units`.
    gates_width = 3 * self.units

    self.kernel = self.add_weight(
        name='kernel',
        shape=(input_dim, gates_width),
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        name='recurrent_kernel',
        shape=(self.units, gates_width),
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    # Two bias blocks per gate: 6 * units entries in total.
    self.bias = self.add_weight(
        name='bias',
        shape=(2 * gates_width,),
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True

  def _process_batch(self, inputs, initial_state):
    """Runs the cuDNN GRU kernel over a whole batch.

    Args:
      inputs: Input tensor; batch-major unless `self.time_major` is True.
      initial_state: List whose first element is the initial hidden state.

    Returns:
      A tuple `(output, [h])`.
    """
    if not self.time_major:
      # The cuDNN op is fed time-major data.
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    # Add a leading axis of size 1 to the state (presumably cuDNN's layer
    # dimension — TODO confirm).
    input_h = array_ops.expand_dims(initial_state[0], axis=0)

    width = self.units
    # The first two gate blocks are handed over swapped relative to how they
    # are stored here; the third block keeps its position.
    gate_order = (1, 0, 2)
    kernel_blocks = [
        self.kernel[:, g * width:(g + 1) * width] for g in gate_order]
    recurrent_blocks = [
        self.recurrent_kernel[:, g * width:(g + 1) * width]
        for g in gate_order]
    input_bias_blocks = [
        self.bias[g * width:(g + 1) * width] for g in gate_order]
    recurrent_bias_blocks = [
        self.bias[(3 + g) * width:(4 + g) * width] for g in gate_order]

    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=kernel_blocks + recurrent_blocks,
        biases=input_bias_blocks + recurrent_bias_blocks,
        shape=self._vector_shape)

    outputs, final_h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs,
        input_h=input_h,
        input_c=0,
        params=params,
        is_training=True,
        rnn_mode='gru')

    if self.stateful or self.return_state:
      # Drop the leading size-1 axis added to the state above.
      final_h = final_h[0]

    if not self.return_sequences:
      output = outputs[-1]
    elif self.time_major:
      output = outputs
    else:
      output = array_ops.transpose(outputs, perm=(1, 0, 2))
    return output, [final_h]

  def get_config(self):
    """Returns the base config extended with this layer's hyperparameters."""
    config = {'units': self.units}
    for attr in ('kernel_initializer', 'recurrent_initializer',
                 'bias_initializer'):
      config[attr] = initializers.serialize(getattr(self, attr))
    for attr in ('kernel_regularizer', 'recurrent_regularizer',
                 'bias_regularizer', 'activity_regularizer'):
      config[attr] = regularizers.serialize(getattr(self, attr))
    for attr in ('kernel_constraint', 'recurrent_constraint',
                 'bias_constraint'):
      config[attr] = constraints.serialize(getattr(self, attr))
    merged = dict(super(CuDNNGRU, self).get_config())
    merged.update(config)
    return merged
336
337
@keras_export(v1=['keras.layers.CuDNNLSTM'])
class CuDNNLSTM(_CuDNNRNN):
  """Fast LSTM implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Arguments:
      units: Positive integer, dimensionality of the output space.
      kernel_initializer: Initializer for the `kernel` weights matrix, used for
        the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, one of the two forget-gate bias
        blocks (entries `[5 * units, 6 * units)` of the bias vector) is
        initialized to ones, while all other bias entries still use
        `bias_initializer`. With the default zero `bias_initializer` this
        starts the forget gate with a bias of 1, as recommended in [Jozefowicz
        et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state in addition to the
        output.
      go_backwards: Boolean (default False). If True, process the input sequence
        backwards and return the reversed sequence.
      stateful: Boolean (default False). If True, the last state for each sample
        at index i in a batch will be used as initial state for the sample of
        index i in the following batch.
      **kwargs: Forwarded to the base class; this includes `time_major`
        (Boolean, default False), which switches inputs and outputs to
        `(timesteps, batch, ...)` layout.
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    # No real cell is built; the base class only needs `cell.state_size`.
    # An LSTM carries two states (hidden and cell), each of width `units`.
    cell_spec = collections.namedtuple('cell', 'state_size')
    self._cell = cell_spec(state_size=(self.units, self.units))
    super(CuDNNLSTM, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def cell(self):
    """The placeholder cell description exposing `state_size`."""
    return self._cell

  def build(self, input_shape):
    """Creates `kernel`, `recurrent_kernel` and `bias` (in that order)."""
    super(CuDNNLSTM, self).build(input_shape)
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    input_dim = int(input_shape[-1])

    # Each weight matrix stacks four gate blocks of width `units`.
    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.unit_forget_bias:

      def bias_initializer(_, *args, **kwargs):
        # The requested shape is ignored: the 8 * units bias is assembled as
        # 5 blocks from `bias_initializer`, 1 block of ones (the forget-gate
        # block at [5 * units, 6 * units)), then 2 more initializer blocks.
        return array_ops.concat([
            self.bias_initializer((self.units * 5,), *args, **kwargs),
            initializers.Ones()((self.units,), *args, **kwargs),
            self.bias_initializer((self.units * 2,), *args, **kwargs),
        ], axis=0)
    else:
      bias_initializer = self.bias_initializer
    # Two bias blocks per gate: 8 * units entries in total.
    self.bias = self.add_weight(
        shape=(self.units * 8,),
        name='bias',
        initializer=bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True

  def _process_batch(self, inputs, initial_state):
    """Runs the cuDNN LSTM kernel over a whole batch.

    Args:
      inputs: Input tensor; batch-major unless `self.time_major` is True.
      initial_state: List `[h, c]` of initial hidden and cell states.

    Returns:
      A tuple `(output, [h, c])`.
    """
    if not self.time_major:
      # The cuDNN op is fed time-major data.
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    # Add a leading axis of size 1 to each state (presumably cuDNN's layer
    # dimension — TODO confirm).
    input_h = array_ops.expand_dims(initial_state[0], axis=0)
    input_c = array_ops.expand_dims(initial_state[1], axis=0)

    units = self.units
    # The four gate blocks of each weight matrix are passed in their stored
    # order; the eight bias blocks are passed in order as well, pairing with
    # the four kernel blocks followed by the four recurrent-kernel blocks.
    kernel_blocks = [
        self.kernel[:, i * units:(i + 1) * units] for i in range(4)]
    recurrent_blocks = [
        self.recurrent_kernel[:, i * units:(i + 1) * units]
        for i in range(4)]
    bias_blocks = [self.bias[i * units:(i + 1) * units] for i in range(8)]

    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=kernel_blocks + recurrent_blocks,
        biases=bias_blocks,
        shape=self._vector_shape)

    # `rnn_mode` is left at the op's default, which is 'lstm'.
    outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs,
        input_h=input_h,
        input_c=input_c,
        params=params,
        is_training=True)

    if self.stateful or self.return_state:
      # Drop the leading size-1 axis added to the states above.
      h = h[0]
      c = c[0]
    if self.return_sequences:
      if self.time_major:
        output = outputs
      else:
        output = array_ops.transpose(outputs, perm=(1, 0, 2))
    else:
      output = outputs[-1]
    return output, [h, c]

  def get_config(self):
    """Returns the base config extended with this layer's hyperparameters."""
    config = {
        'units': self.units,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'unit_forget_bias': self.unit_forget_bias,
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(CuDNNLSTM, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
536