# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN cell wrappers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import numbers
import sys
import types as python_types
import warnings

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class DropoutWrapperBase(object):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self,
               cell,
               input_keep_prob=1.0,
               output_keep_prob=1.0,
               state_keep_prob=1.0,
               variational_recurrent=False,
               input_size=None,
               dtype=None,
               seed=None,
               dropout_state_filter_visitor=None,
               **kwargs):
    """Create a cell with added input, state, and/or output dropout.

    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:
    [A Theoretically Grounded Application of Dropout in Recurrent
    Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287).

    Otherwise a different dropout mask is applied at every time step.

    Note, by default (unless a custom `dropout_state_filter` is provided),
    the memory state (`c` component of any `LSTMStateTuple`) passing through
    a `DropoutWrapper` is never modified. This behavior is described in the
    above article.

    Args:
      cell: an RNNCell; dropout is applied to its inputs, states, and/or
        outputs.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell. **Note**
        the state components to which dropout is applied when `state_keep_prob`
        is in `(0, 1)` are also determined by the argument
        `dropout_state_filter_visitor` (e.g. by default dropout is never applied
        to the `c` component of an `LSTMStateTuple`).
      variational_recurrent: Python bool. If `True`, then the same dropout
        pattern is applied across all time steps per run call. If this
        parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in
        to the `DropoutWrapper`. Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.
      dropout_state_filter_visitor: (optional), default: (see below). Function
        that takes any hierarchical level of the state and returns a scalar or
        depth=1 structure of Python booleans describing which terms in the
        state should be dropped out. In addition, if the function returns
        `True`, dropout is applied across this sublevel. If the function
        returns `False`, dropout is not applied across this entire sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMStateTuple` objects, and don't try to apply dropout to
        `TensorArray` objects:

        ```
        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMStateTuple):
            # Never perform dropout on the c state.
            return LSTMStateTuple(c=False, h=True)
          elif isinstance(s, TensorArray):
            return False
          return True
        ```
      **kwargs: dict of keyword arguments for base layer.

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or
        `dropout_state_filter_visitor` is provided but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)

    if (dropout_state_filter_visitor is not None and
        not callable(dropout_state_filter_visitor)):
      raise TypeError("dropout_state_filter_visitor must be callable")
    self._dropout_state_filter = (
        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
    with ops.name_scope_v2("DropoutWrapperInit"):

      def tensor_and_const_value(v):
        tensor_value = ops.convert_to_tensor_v2_with_dispatch(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)

      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g" %
                             (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent, seed before running the code below
    self._variational_recurrent = variational_recurrent
    self._input_size = input_size
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
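        # The resulting [1, ...]-shaped noise later broadcasts against the
        # full [batch, ...] tensors in _variational_recurrent_dropout_value.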
        return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def build(self, inputs_shape):
    self.cell.build(inputs_shape)
    self.built = True

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, unused_index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.divide(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret

  def _dropout(self,
               values,
               salt_prefix,
               recurrent_noise,
               keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check to see if leafs of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:

      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout_v2(
              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:

      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """

    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output", self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state

  def get_config(self):
    """Returns the config of the dropout wrapper."""
    config = {
        "input_keep_prob": self._input_keep_prob,
        "output_keep_prob": self._output_keep_prob,
        "state_keep_prob": self._state_keep_prob,
        "variational_recurrent": self._variational_recurrent,
        "input_size": self._input_size,
        "seed": self._seed,
    }
    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
      function, function_type, function_module = _serialize_function_to_config(
          self._dropout_state_filter)
      config.update({"dropout_fn": function,
                     "dropout_fn_type": function_type,
                     "dropout_fn_module": function_module})
    base_config = super(DropoutWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "dropout_fn" in config:
      config = config.copy()
      dropout_state_filter = _parse_config_to_function(
          config, custom_objects, "dropout_fn", "dropout_fn_type",
          "dropout_fn_module")
      config.pop("dropout_fn")
      config["dropout_state_filter_visitor"] = dropout_state_filter
    return super(DropoutWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)
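

# Illustrative usage sketch (kept as a comment; not executed as part of this
# module). A concrete wrapper built on `DropoutWrapperBase` is exported in
# TF 2.x as `tf.nn.RNNCellDropoutWrapper`; the cell type and shape values below
# are assumptions chosen only for the example.
#
#   import tensorflow as tf
#
#   cell = tf.keras.layers.LSTMCell(64)
#   wrapped = tf.nn.RNNCellDropoutWrapper(
#       cell,
#       input_keep_prob=0.9,
#       state_keep_prob=0.9,
#       # Reuse the same dropout mask at every time step (Gal & Ghahramani).
#       variational_recurrent=True,
#       # Per-timestep input depth, required when variational_recurrent=True
#       # and input_keep_prob < 1.
#       input_size=tf.TensorShape([32]),
#       dtype=tf.float32)
#   outputs = tf.keras.layers.RNN(wrapped)(tf.random.normal([8, 10, 32]))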


class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None, **kwargs):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network.
        Defaults to calling
        `nest.map_structure(lambda i, o: i + o, inputs, outputs)`.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the cell, then applies `residual_fn` to its inputs and outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)

    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)

  def get_config(self):
    """Returns the config of the residual wrapper."""
    if self._residual_fn is not None:
      function, function_type, function_module = _serialize_function_to_config(
          self._residual_fn)
      config = {
          "residual_fn": function,
          "residual_fn_type": function_type,
          "residual_fn_module": function_module
      }
    else:
      config = {}
    base_config = super(ResidualWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "residual_fn" in config:
      config = config.copy()
      residual_function = _parse_config_to_function(config, custom_objects,
                                                    "residual_fn",
                                                    "residual_fn_type",
                                                    "residual_fn_module")
      config["residual_fn"] = residual_function
    return super(ResidualWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)
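

# Illustrative sketch (kept as a comment; symbol names are the TF 2.x exports
# and may differ in other versions). The concrete subclass is exported as
# `tf.nn.RNNCellResidualWrapper`. With the default residual_fn, cell inputs are
# added element-wise to cell outputs, so input and output depths must match.
#
#   import tensorflow as tf
#
#   cell = tf.keras.layers.GRUCell(32)
#   residual_cell = tf.nn.RNNCellResidualWrapper(cell)
#   # Inputs have depth 32, matching the GRU output depth.
#   outputs = tf.keras.layers.RNN(residual_cell)(tf.random.normal([8, 10, 32]))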


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device, **kwargs):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(DeviceWrapperBase, self).__init__(cell, **kwargs)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)

  def get_config(self):
    config = {"device": self._device}
    base_config = super(DeviceWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _serialize_function_to_config(function):
  """Serialize the function for get_config()."""
  if isinstance(function, python_types.LambdaType):
    output = generic_utils.func_dump(function)
    output_type = "lambda"
    module = function.__module__
  elif callable(function):
    output = function.__name__
    output_type = "function"
    module = function.__module__
  else:
    raise ValueError("Unrecognized function type for input: {}".format(
        type(function)))

  return output, output_type, module


def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config."""
  globs = globals()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn("{} is not loaded, but a layer uses it. "
                  "It may cause errors.".format(module), UserWarning)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    function = generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  elif function_type == "lambda":
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(
        config[func_attr_name], globs=globs)
  else:
    raise TypeError("Unknown function type:", function_type)
  return function


def _default_dropout_state_filter_visitor(substate):
  from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True


def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  ix = [0]

  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r

  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
                                  **kwargs)
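
# Illustrative sketch (kept as a comment; device string and cell type are
# assumptions). The concrete subclass is exported in TF 2.x as
# `tf.nn.RNNCellDeviceWrapper` and pins every step of the wrapped cell to the
# given device. Wrappers are themselves cells, so they can be nested, e.g. a
# DeviceWrapper around a DropoutWrapper.
#
#   import tensorflow as tf
#
#   cell = tf.keras.layers.LSTMCell(64)
#   gpu_cell = tf.nn.RNNCellDeviceWrapper(cell, device="/device:GPU:0")
#   outputs = tf.keras.layers.RNN(gpu_cell)(tf.random.normal([8, 10, 16]))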