# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Recurrent layers backed by cuDNN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util.tf_export import keras_export


class _CuDNNRNN(RNN):
  """Private base class for CuDNNGRU and CuDNNLSTM layers.

  Args:
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    time_major: Boolean (default False). If True, the inputs and outputs will
      be in shape `(timesteps, batch, ...)`, whereas in the False case, it
      will be `(batch, timesteps, ...)`.
  """

  def __init__(self,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               time_major=False,
               **kwargs):
    # We invoke the base layer's initializer directly here because we do not
    # want to create an RNN cell instance.
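    # (There is no per-timestep cell to run: the fused cuDNN kernel consumes
    # the whole input sequence at once, so these layers only reuse RNN's
    # input/state bookkeeping and expose a lightweight namedtuple as
    # `self.cell`.)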
    super(RNN, self).__init__(**kwargs)  # pylint: disable=bad-super-call
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.time_major = time_major
    self.supports_masking = False
    self.input_spec = [InputSpec(ndim=3)]
    if hasattr(self.cell.state_size, '__len__'):
      state_size = self.cell.state_size
    else:
      state_size = [self.cell.state_size]
    self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
    self.constants_spec = None
    self._states = None
    self._num_constants = 0
    self._vector_shape = constant_op.constant([-1])

  def call(self, inputs, mask=None, training=None, initial_state=None):
    if isinstance(mask, list):
      mask = mask[0]
    if mask is not None:
      raise ValueError('Masking is not supported for CuDNN RNNs.')

    # input shape: `(samples, time (padded with zeros), input_dim)`
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      initial_state = inputs[1:]
      inputs = inputs[0]
    elif initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')

    if self.go_backwards:
      # Reverse time axis.
      inputs = K.reverse(inputs, 1)
    output, states = self._process_batch(inputs, initial_state)

    if self.stateful:
      updates = [
          state_ops.assign(self_state, state)
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_state:
      return [output] + states
    else:
      return output

  def get_config(self):
    config = {
        'return_sequences': self.return_sequences,
        'return_state': self.return_state,
        'go_backwards': self.go_backwards,
        'stateful': self.stateful,
        'time_major': self.time_major,
    }
    base_config = super(  # pylint: disable=bad-super-call
        RNN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)

  @property
  def trainable_weights(self):
    if self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def non_trainable_weights(self):
    if not self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def losses(self):
    return super(RNN, self).losses

  def get_losses_for(self, inputs=None):
    return super(  # pylint: disable=bad-super-call
        RNN, self).get_losses_for(inputs=inputs)


@keras_export(v1=['keras.layers.CuDNNGRU'])
class CuDNNGRU(_CuDNNRNN):
  """Fast GRU implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Args:
    units: Positive integer, dimensionality of the output space.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the
      `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output in the output
      sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
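
  Example (a minimal sketch, assuming a CUDA-enabled GPU; in TF 2.x the layer
  is available as `tf.compat.v1.keras.layers.CuDNNGRU`):

  ```python
  inputs = tf.keras.layers.Input(shape=(10, 8))   # (timesteps, input_dim)
  output = tf.compat.v1.keras.layers.CuDNNGRU(32)(inputs)  # -> (batch, 32)
  model = tf.keras.Model(inputs, output)
  ```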
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    cell_spec = collections.namedtuple('cell', 'state_size')
    self._cell = cell_spec(state_size=self.units)
    super(CuDNNGRU, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def cell(self):
    return self._cell

  def build(self, input_shape):
    super(CuDNNGRU, self).build(input_shape)
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    input_dim = int(input_shape[-1])

    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 3),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 3),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)
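
    # The cuDNN-compatible bias keeps a separate input bias and recurrent
    # bias for each of the three gates, hence `units * 6` entries rather than
    # the `units * 3` a fused-bias GRU would use.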
    self.bias = self.add_weight(
        shape=(self.units * 6,),
        name='bias',
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True

  def _process_batch(self, inputs, initial_state):
    if not self.time_major:
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    input_h = initial_state[0]
    input_h = array_ops.expand_dims(input_h, axis=0)

    # Keras stores GRU weights gate-by-gate in (z, r, h) order, while cuDNN
    # expects (r, z, h); the slices below swap the first two gates in the
    # kernels and in both halves of the bias accordingly.
    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=[
            self.kernel[:, self.units:self.units * 2],
            self.kernel[:, :self.units],
            self.kernel[:, self.units * 2:],
            self.recurrent_kernel[:, self.units:self.units * 2],
            self.recurrent_kernel[:, :self.units],
            self.recurrent_kernel[:, self.units * 2:],
        ],
        biases=[
            self.bias[self.units:self.units * 2],
            self.bias[:self.units],
            self.bias[self.units * 2:self.units * 3],
            self.bias[self.units * 4:self.units * 5],
            self.bias[self.units * 3:self.units * 4],
            self.bias[self.units * 5:],
        ],
        shape=self._vector_shape)

    args = {
        'input': inputs,
        'input_h': input_h,
        'input_c': 0,
        'params': params,
        'is_training': True,
        'rnn_mode': 'gru',
    }

    outputs, h, _, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)

    if self.stateful or self.return_state:
      h = h[0]
    if self.return_sequences:
      if self.time_major:
        output = outputs
      else:
        output = array_ops.transpose(outputs, perm=(1, 0, 2))
    else:
      output = outputs[-1]
    return output, [h]

  def get_config(self):
    config = {
        'units': self.units,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(CuDNNGRU, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=['keras.layers.CuDNNLSTM'])
class CuDNNLSTM(_CuDNNRNN):
  """Fast LSTM implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Args:
    units: Positive integer, dimensionality of the output space.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate
      at initialization. Setting it to True will also force
      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
      al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the
      `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output in the
      output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
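
  Example (a minimal sketch, assuming a CUDA-enabled GPU; in TF 2.x the layer
  is available as `tf.compat.v1.keras.layers.CuDNNLSTM`):

  ```python
  inputs = tf.keras.layers.Input(shape=(10, 8))
  # With `return_state=True`, the layer returns `[output, h, c]`.
  output, h, c = tf.compat.v1.keras.layers.CuDNNLSTM(
      32, return_state=True)(inputs)
  model = tf.keras.Model(inputs, [output, h, c])
  ```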
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    cell_spec = collections.namedtuple('cell', 'state_size')
    self._cell = cell_spec(state_size=(self.units, self.units))
    super(CuDNNLSTM, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def cell(self):
    return self._cell

  def build(self, input_shape):
    super(CuDNNLSTM, self).build(input_shape)
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    input_dim = int(input_shape[-1])

    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)
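
    # The cuDNN-compatible bias holds eight `units`-sized segments: an input
    # bias and a recurrent bias for each of the four gates, in cuDNN's
    # (input, forget, cell, output) order. `unit_forget_bias` is implemented
    # below by seeding the recurrent forget-gate segment (index 5) with ones
    # and everything else with the user-provided initializer.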
    if self.unit_forget_bias:

      def bias_initializer(_, *args, **kwargs):
        return array_ops.concat([
            self.bias_initializer((self.units * 5,), *args, **kwargs),
            initializers.Ones()((self.units,), *args, **kwargs),
            self.bias_initializer((self.units * 2,), *args, **kwargs),
        ], axis=0)
    else:
      bias_initializer = self.bias_initializer
    self.bias = self.add_weight(
        shape=(self.units * 8,),
        name='bias',
        initializer=bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True

  def _process_batch(self, inputs, initial_state):
    if not self.time_major:
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    input_h = initial_state[0]
    input_c = initial_state[1]
    input_h = array_ops.expand_dims(input_h, axis=0)
    input_c = array_ops.expand_dims(input_c, axis=0)

    # Keras and cuDNN agree on the LSTM gate order (i, f, c, o), so the
    # kernels and biases are passed through in their stored order.
    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=[
            self.kernel[:, :self.units],
            self.kernel[:, self.units:self.units * 2],
            self.kernel[:, self.units * 2:self.units * 3],
            self.kernel[:, self.units * 3:],
            self.recurrent_kernel[:, :self.units],
            self.recurrent_kernel[:, self.units:self.units * 2],
            self.recurrent_kernel[:, self.units * 2:self.units * 3],
            self.recurrent_kernel[:, self.units * 3:],
        ],
        biases=[
            self.bias[:self.units],
            self.bias[self.units:self.units * 2],
            self.bias[self.units * 2:self.units * 3],
            self.bias[self.units * 3:self.units * 4],
            self.bias[self.units * 4:self.units * 5],
            self.bias[self.units * 5:self.units * 6],
            self.bias[self.units * 6:self.units * 7],
            self.bias[self.units * 7:],
        ],
        shape=self._vector_shape)

    args = {
        'input': inputs,
        'input_h': input_h,
        'input_c': input_c,
        'params': params,
        'is_training': True,
    }

    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)

    if self.stateful or self.return_state:
      h = h[0]
      c = c[0]
    if self.return_sequences:
      if self.time_major:
        output = outputs
      else:
        output = array_ops.transpose(outputs, perm=(1, 0, 2))
    else:
      output = outputs[-1]
    return output, [h, c]

  def get_config(self):
    config = {
        'units': self.units,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'unit_forget_bias': self.unit_forget_bias,
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(CuDNNLSTM, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
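

# Example usage (a sketch only, not executed with this module; assumes a
# CUDA-enabled GPU and the TF 1.x export path `tf.keras.layers.CuDNN*`):
#
#   import numpy as np
#   import tensorflow as tf
#
#   model = tf.keras.models.Sequential([
#       tf.keras.layers.CuDNNGRU(32, return_sequences=True,
#                                input_shape=(10, 8)),
#       tf.keras.layers.CuDNNLSTM(16),
#       tf.keras.layers.Dense(1),
#   ])
#   model.compile(optimizer='adam', loss='mse')
#   model.fit(np.random.rand(64, 10, 8).astype(np.float32),
#             np.random.rand(64, 1).astype(np.float32), batch_size=16)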