1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Recurrent layers backed by cuDNN.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.keras import backend as K
25from tensorflow.python.keras import constraints
26from tensorflow.python.keras import initializers
27from tensorflow.python.keras import regularizers
28from tensorflow.python.keras.engine.input_spec import InputSpec
29from tensorflow.python.keras.layers import recurrent_v2
30from tensorflow.python.keras.layers.recurrent import RNN
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import gen_cudnn_rnn_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.util.tf_export import keras_export
35
36
37class _CuDNNRNN(RNN):
38  """Private base class for CuDNNGRU and CuDNNLSTM layers.
39
40  Args:
41    return_sequences: Boolean. Whether to return the last output
42        in the output sequence, or the full sequence.
43    return_state: Boolean. Whether to return the last state
44        in addition to the output.
45    go_backwards: Boolean (default False).
46        If True, process the input sequence backwards and return the
47        reversed sequence.
48    stateful: Boolean (default False). If True, the last state
49        for each sample at index i in a batch will be used as initial
50        state for the sample of index i in the following batch.
51    time_major: Boolean (default False). If true, the inputs and outputs will be
52        in shape `(timesteps, batch, ...)`, whereas in the False case, it will
53        be `(batch, timesteps, ...)`.
54  """
55
56  def __init__(self,
57               return_sequences=False,
58               return_state=False,
59               go_backwards=False,
60               stateful=False,
61               time_major=False,
62               **kwargs):
63    # We invoke the base layer's initializer directly here because we do not
64    # want to create RNN cell instance.
65    super(RNN, self).__init__(**kwargs)  # pylint: disable=bad-super-call
66    self.return_sequences = return_sequences
67    self.return_state = return_state
68    self.go_backwards = go_backwards
69    self.stateful = stateful
70    self.time_major = time_major
71    self.supports_masking = False
72    self.input_spec = [InputSpec(ndim=3)]
73    if hasattr(self.cell.state_size, '__len__'):
74      state_size = self.cell.state_size
75    else:
76      state_size = [self.cell.state_size]
77    self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
78    self.constants_spec = None
79    self._states = None
80    self._num_constants = 0
81    self._vector_shape = constant_op.constant([-1])
82
83  def call(self, inputs, mask=None, training=None, initial_state=None):
84    if isinstance(mask, list):
85      mask = mask[0]
86    if mask is not None:
87      raise ValueError('Masking is not supported for CuDNN RNNs.')
88
89    # input shape: `(samples, time (padded with zeros), input_dim)`
90    # note that the .build() method of subclasses MUST define
91    # self.input_spec and self.state_spec with complete input shapes.
92    if isinstance(inputs, list):
93      initial_state = inputs[1:]
94      inputs = inputs[0]
95    elif initial_state is not None:
96      pass
97    elif self.stateful:
98      initial_state = self.states
99    else:
100      initial_state = self.get_initial_state(inputs)
101
102    if len(initial_state) != len(self.states):
103      raise ValueError('Layer has ' + str(len(self.states)) +
104                       ' states but was passed ' + str(len(initial_state)) +
105                       ' initial states.')
106
107    if self.go_backwards:
108      # Reverse time axis.
109      inputs = K.reverse(inputs, 1)
110    output, states = self._process_batch(inputs, initial_state)
111
112    if self.stateful:
113      updates = [
114          state_ops.assign(self_state, state)
115          for self_state, state in zip(self.states, states)
116      ]
117      self.add_update(updates)
118
119    if self.return_state:
120      return [output] + states
121    else:
122      return output
123
124  def get_config(self):
125    config = {
126        'return_sequences': self.return_sequences,
127        'return_state': self.return_state,
128        'go_backwards': self.go_backwards,
129        'stateful': self.stateful,
130        'time_major': self.time_major,
131    }
132    base_config = super(  # pylint: disable=bad-super-call
133        RNN, self).get_config()
134    return dict(list(base_config.items()) + list(config.items()))
135
136  @classmethod
137  def from_config(cls, config):
138    return cls(**config)
139
140  @property
141  def trainable_weights(self):
142    if self.trainable and self.built:
143      return [self.kernel, self.recurrent_kernel, self.bias]
144    return []
145
146  @property
147  def non_trainable_weights(self):
148    if not self.trainable and self.built:
149      return [self.kernel, self.recurrent_kernel, self.bias]
150    return []
151
152  @property
153  def losses(self):
154    return super(RNN, self).losses
155
156  def get_losses_for(self, inputs=None):
157    return super(  # pylint: disable=bad-super-call
158        RNN, self).get_losses_for(inputs=inputs)
159
160
161@keras_export(v1=['keras.layers.CuDNNGRU'])
162class CuDNNGRU(_CuDNNRNN):
163  """Fast GRU implementation backed by cuDNN.
164
165  More information about cuDNN can be found on the [NVIDIA
166  developer website](https://developer.nvidia.com/cudnn).
167  Can only be run on GPU.
168
169  Args:
170      units: Positive integer, dimensionality of the output space.
171      kernel_initializer: Initializer for the `kernel` weights matrix, used for
172        the linear transformation of the inputs.
173      recurrent_initializer: Initializer for the `recurrent_kernel` weights
174        matrix, used for the linear transformation of the recurrent state.
175      bias_initializer: Initializer for the bias vector.
176      kernel_regularizer: Regularizer function applied to the `kernel` weights
177        matrix.
178      recurrent_regularizer: Regularizer function applied to the
179        `recurrent_kernel` weights matrix.
180      bias_regularizer: Regularizer function applied to the bias vector.
181      activity_regularizer: Regularizer function applied to the output of the
182        layer (its "activation").
183      kernel_constraint: Constraint function applied to the `kernel` weights
184        matrix.
185      recurrent_constraint: Constraint function applied to the
186        `recurrent_kernel` weights matrix.
187      bias_constraint: Constraint function applied to the bias vector.
188      return_sequences: Boolean. Whether to return the last output in the output
189        sequence, or the full sequence.
190      return_state: Boolean. Whether to return the last state in addition to the
191        output.
192      go_backwards: Boolean (default False). If True, process the input sequence
193        backwards and return the reversed sequence.
194      stateful: Boolean (default False). If True, the last state for each sample
195        at index i in a batch will be used as initial state for the sample of
196        index i in the following batch.
197  """
198
199  def __init__(self,
200               units,
201               kernel_initializer='glorot_uniform',
202               recurrent_initializer='orthogonal',
203               bias_initializer='zeros',
204               kernel_regularizer=None,
205               recurrent_regularizer=None,
206               bias_regularizer=None,
207               activity_regularizer=None,
208               kernel_constraint=None,
209               recurrent_constraint=None,
210               bias_constraint=None,
211               return_sequences=False,
212               return_state=False,
213               go_backwards=False,
214               stateful=False,
215               **kwargs):
216    self.units = units
217    cell_spec = collections.namedtuple('cell', 'state_size')
218    self._cell = cell_spec(state_size=self.units)
219    super(CuDNNGRU, self).__init__(
220        return_sequences=return_sequences,
221        return_state=return_state,
222        go_backwards=go_backwards,
223        stateful=stateful,
224        **kwargs)
225
226    self.kernel_initializer = initializers.get(kernel_initializer)
227    self.recurrent_initializer = initializers.get(recurrent_initializer)
228    self.bias_initializer = initializers.get(bias_initializer)
229
230    self.kernel_regularizer = regularizers.get(kernel_regularizer)
231    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
232    self.bias_regularizer = regularizers.get(bias_regularizer)
233    self.activity_regularizer = regularizers.get(activity_regularizer)
234
235    self.kernel_constraint = constraints.get(kernel_constraint)
236    self.recurrent_constraint = constraints.get(recurrent_constraint)
237    self.bias_constraint = constraints.get(bias_constraint)
238
239  @property
240  def cell(self):
241    return self._cell
242
243  def build(self, input_shape):
244    super(CuDNNGRU, self).build(input_shape)
245    if isinstance(input_shape, list):
246      input_shape = input_shape[0]
247    input_dim = int(input_shape[-1])
248
249    self.kernel = self.add_weight(
250        shape=(input_dim, self.units * 3),
251        name='kernel',
252        initializer=self.kernel_initializer,
253        regularizer=self.kernel_regularizer,
254        constraint=self.kernel_constraint)
255
256    self.recurrent_kernel = self.add_weight(
257        shape=(self.units, self.units * 3),
258        name='recurrent_kernel',
259        initializer=self.recurrent_initializer,
260        regularizer=self.recurrent_regularizer,
261        constraint=self.recurrent_constraint)
262
263    self.bias = self.add_weight(
264        shape=(self.units * 6,),
265        name='bias',
266        initializer=self.bias_initializer,
267        regularizer=self.bias_regularizer,
268        constraint=self.bias_constraint)
269
270    self.built = True
271
272  def _process_batch(self, inputs, initial_state):
273    if not self.time_major:
274      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
275    input_h = initial_state[0]
276    input_h = array_ops.expand_dims(input_h, axis=0)
277
278    params = recurrent_v2._canonical_to_params(    # pylint: disable=protected-access
279        weights=[
280            self.kernel[:, self.units:self.units * 2],
281            self.kernel[:, :self.units],
282            self.kernel[:, self.units * 2:],
283            self.recurrent_kernel[:, self.units:self.units * 2],
284            self.recurrent_kernel[:, :self.units],
285            self.recurrent_kernel[:, self.units * 2:],
286        ],
287        biases=[
288            self.bias[self.units:self.units * 2],
289            self.bias[:self.units],
290            self.bias[self.units * 2:self.units * 3],
291            self.bias[self.units * 4:self.units * 5],
292            self.bias[self.units * 3:self.units * 4],
293            self.bias[self.units * 5:],
294        ],
295        shape=self._vector_shape)
296
297    args = {
298        'input': inputs,
299        'input_h': input_h,
300        'input_c': 0,
301        'params': params,
302        'is_training': True,
303        'rnn_mode': 'gru',
304    }
305
306    outputs, h, _, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)
307
308    if self.stateful or self.return_state:
309      h = h[0]
310    if self.return_sequences:
311      if self.time_major:
312        output = outputs
313      else:
314        output = array_ops.transpose(outputs, perm=(1, 0, 2))
315    else:
316      output = outputs[-1]
317    return output, [h]
318
319  def get_config(self):
320    config = {
321        'units': self.units,
322        'kernel_initializer': initializers.serialize(self.kernel_initializer),
323        'recurrent_initializer':
324            initializers.serialize(self.recurrent_initializer),
325        'bias_initializer': initializers.serialize(self.bias_initializer),
326        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
327        'recurrent_regularizer':
328            regularizers.serialize(self.recurrent_regularizer),
329        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
330        'activity_regularizer':
331            regularizers.serialize(self.activity_regularizer),
332        'kernel_constraint': constraints.serialize(self.kernel_constraint),
333        'recurrent_constraint':
334            constraints.serialize(self.recurrent_constraint),
335        'bias_constraint': constraints.serialize(self.bias_constraint)
336    }
337    base_config = super(CuDNNGRU, self).get_config()
338    return dict(list(base_config.items()) + list(config.items()))
339
340
341@keras_export(v1=['keras.layers.CuDNNLSTM'])
342class CuDNNLSTM(_CuDNNRNN):
343  """Fast LSTM implementation backed by cuDNN.
344
345  More information about cuDNN can be found on the [NVIDIA
346  developer website](https://developer.nvidia.com/cudnn).
347  Can only be run on GPU.
348
349  Args:
350      units: Positive integer, dimensionality of the output space.
351      kernel_initializer: Initializer for the `kernel` weights matrix, used for
352        the linear transformation of the inputs.
353      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate
354        at initialization. Setting it to true will also force
355        `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
356        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
357      recurrent_initializer: Initializer for the `recurrent_kernel` weights
358        matrix, used for the linear transformation of the recurrent state.
359      bias_initializer: Initializer for the bias vector.
360      kernel_regularizer: Regularizer function applied to the `kernel` weights
361        matrix.
362      recurrent_regularizer: Regularizer function applied to the
363        `recurrent_kernel` weights matrix.
364      bias_regularizer: Regularizer function applied to the bias vector.
365      activity_regularizer: Regularizer function applied to the output of the
366        layer (its "activation").
367      kernel_constraint: Constraint function applied to the `kernel` weights
368        matrix.
369      recurrent_constraint: Constraint function applied to the
370        `recurrent_kernel` weights matrix.
371      bias_constraint: Constraint function applied to the bias vector.
372      return_sequences: Boolean. Whether to return the last output. in the
373        output sequence, or the full sequence.
374      return_state: Boolean. Whether to return the last state in addition to the
375        output.
376      go_backwards: Boolean (default False). If True, process the input sequence
377        backwards and return the reversed sequence.
378      stateful: Boolean (default False). If True, the last state for each sample
379        at index i in a batch will be used as initial state for the sample of
380        index i in the following batch.
381  """
382
383  def __init__(self,
384               units,
385               kernel_initializer='glorot_uniform',
386               recurrent_initializer='orthogonal',
387               bias_initializer='zeros',
388               unit_forget_bias=True,
389               kernel_regularizer=None,
390               recurrent_regularizer=None,
391               bias_regularizer=None,
392               activity_regularizer=None,
393               kernel_constraint=None,
394               recurrent_constraint=None,
395               bias_constraint=None,
396               return_sequences=False,
397               return_state=False,
398               go_backwards=False,
399               stateful=False,
400               **kwargs):
401    self.units = units
402    cell_spec = collections.namedtuple('cell', 'state_size')
403    self._cell = cell_spec(state_size=(self.units, self.units))
404    super(CuDNNLSTM, self).__init__(
405        return_sequences=return_sequences,
406        return_state=return_state,
407        go_backwards=go_backwards,
408        stateful=stateful,
409        **kwargs)
410
411    self.kernel_initializer = initializers.get(kernel_initializer)
412    self.recurrent_initializer = initializers.get(recurrent_initializer)
413    self.bias_initializer = initializers.get(bias_initializer)
414    self.unit_forget_bias = unit_forget_bias
415
416    self.kernel_regularizer = regularizers.get(kernel_regularizer)
417    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
418    self.bias_regularizer = regularizers.get(bias_regularizer)
419    self.activity_regularizer = regularizers.get(activity_regularizer)
420
421    self.kernel_constraint = constraints.get(kernel_constraint)
422    self.recurrent_constraint = constraints.get(recurrent_constraint)
423    self.bias_constraint = constraints.get(bias_constraint)
424
425  @property
426  def cell(self):
427    return self._cell
428
429  def build(self, input_shape):
430    super(CuDNNLSTM, self).build(input_shape)
431    if isinstance(input_shape, list):
432      input_shape = input_shape[0]
433    input_dim = int(input_shape[-1])
434
435    self.kernel = self.add_weight(
436        shape=(input_dim, self.units * 4),
437        name='kernel',
438        initializer=self.kernel_initializer,
439        regularizer=self.kernel_regularizer,
440        constraint=self.kernel_constraint)
441
442    self.recurrent_kernel = self.add_weight(
443        shape=(self.units, self.units * 4),
444        name='recurrent_kernel',
445        initializer=self.recurrent_initializer,
446        regularizer=self.recurrent_regularizer,
447        constraint=self.recurrent_constraint)
448
449    if self.unit_forget_bias:
450
451      def bias_initializer(_, *args, **kwargs):
452        return array_ops.concat([
453            self.bias_initializer((self.units * 5,), *args, **kwargs),
454            initializers.Ones()((self.units,), *args, **kwargs),
455            self.bias_initializer((self.units * 2,), *args, **kwargs),
456        ], axis=0)
457    else:
458      bias_initializer = self.bias_initializer
459    self.bias = self.add_weight(
460        shape=(self.units * 8,),
461        name='bias',
462        initializer=bias_initializer,
463        regularizer=self.bias_regularizer,
464        constraint=self.bias_constraint)
465
466    self.built = True
467
468  def _process_batch(self, inputs, initial_state):
469    if not self.time_major:
470      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
471    input_h = initial_state[0]
472    input_c = initial_state[1]
473    input_h = array_ops.expand_dims(input_h, axis=0)
474    input_c = array_ops.expand_dims(input_c, axis=0)
475
476    params = recurrent_v2._canonical_to_params(    # pylint: disable=protected-access
477        weights=[
478            self.kernel[:, :self.units],
479            self.kernel[:, self.units:self.units * 2],
480            self.kernel[:, self.units * 2:self.units * 3],
481            self.kernel[:, self.units * 3:],
482            self.recurrent_kernel[:, :self.units],
483            self.recurrent_kernel[:, self.units:self.units * 2],
484            self.recurrent_kernel[:, self.units * 2:self.units * 3],
485            self.recurrent_kernel[:, self.units * 3:],
486        ],
487        biases=[
488            self.bias[:self.units],
489            self.bias[self.units:self.units * 2],
490            self.bias[self.units * 2:self.units * 3],
491            self.bias[self.units * 3:self.units * 4],
492            self.bias[self.units * 4:self.units * 5],
493            self.bias[self.units * 5:self.units * 6],
494            self.bias[self.units * 6:self.units * 7],
495            self.bias[self.units * 7:],
496        ],
497        shape=self._vector_shape)
498
499    args = {
500        'input': inputs,
501        'input_h': input_h,
502        'input_c': input_c,
503        'params': params,
504        'is_training': True,
505    }
506
507    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)
508
509    if self.stateful or self.return_state:
510      h = h[0]
511      c = c[0]
512    if self.return_sequences:
513      if self.time_major:
514        output = outputs
515      else:
516        output = array_ops.transpose(outputs, perm=(1, 0, 2))
517    else:
518      output = outputs[-1]
519    return output, [h, c]
520
521  def get_config(self):
522    config = {
523        'units': self.units,
524        'kernel_initializer': initializers.serialize(self.kernel_initializer),
525        'recurrent_initializer':
526            initializers.serialize(self.recurrent_initializer),
527        'bias_initializer': initializers.serialize(self.bias_initializer),
528        'unit_forget_bias': self.unit_forget_bias,
529        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
530        'recurrent_regularizer':
531            regularizers.serialize(self.recurrent_regularizer),
532        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
533        'activity_regularizer':
534            regularizers.serialize(self.activity_regularizer),
535        'kernel_constraint': constraints.serialize(self.kernel_constraint),
536        'recurrent_constraint':
537            constraints.serialize(self.recurrent_constraint),
538        'bias_constraint': constraints.serialize(self.bias_constraint)
539    }
540    base_config = super(CuDNNLSTM, self).get_config()
541    return dict(list(base_config.items()) + list(config.items()))
542