# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Wrapper layers: layers that augment the functionality of another layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Wrapper')
class Wrapper(Layer):
  """Abstract wrapper base class.

  Wrappers take another layer and augment it in various ways.
  Do not use this class as a layer; it is only an abstract base class.
  Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

  Arguments:
    layer: The layer to be wrapped.
  """

  def __init__(self, layer, **kwargs):
    assert isinstance(layer, Layer)
    self.layer = layer
    # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
    # the inner layer has update ops that depend on its inputs (as opposed
    # to the inputs to the Wrapper layer).
    self._input_map = {}
    super(Wrapper, self).__init__(**kwargs)

  def build(self, input_shape=None):
    self.built = True

  @property
  def activity_regularizer(self):
    if hasattr(self.layer, 'activity_regularizer'):
      return self.layer.activity_regularizer
    else:
      return None

  def get_config(self):
    config = {
        'layer': {
            'class_name': self.layer.__class__.__name__,
            'config': self.layer.get_config()
        }
    }
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    layer = deserialize_layer(
        config.pop('layer'), custom_objects=custom_objects)
    return cls(layer, **config)


@keras_export('keras.layers.TimeDistributed')
class TimeDistributed(Wrapper):
  """This wrapper allows applying a layer to every temporal slice of an input.

  The input should be at least 3D, and the dimension at index one
  will be considered to be the temporal dimension.

  Consider a batch of 32 samples,
  where each sample is a sequence of 10 vectors of 16 dimensions.
  The batch input shape of the layer is then `(32, 10, 16)`,
  and the `input_shape`, not including the samples dimension, is `(10, 16)`.

  You can then use `TimeDistributed` to apply a `Dense` layer
  to each of the 10 timesteps, independently:

  ```python
  # as the first layer in a model
  model = Sequential()
  model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))
  # now model.output_shape == (None, 10, 8)
  ```

  The output will then have shape `(32, 10, 8)`.

  In subsequent layers, there is no need for the `input_shape`:

  ```python
  model.add(TimeDistributed(Dense(32)))
  # now model.output_shape == (None, 10, 32)
  ```

  The output will then have shape `(32, 10, 32)`.

  `TimeDistributed` can be used with arbitrary layers, not just `Dense`,
  for instance with a `Conv2D` layer:

  ```python
  model = Sequential()
  model.add(TimeDistributed(Conv2D(64, (3, 3)),
                            input_shape=(10, 299, 299, 3)))
  ```

  Arguments:
    layer: a layer instance.

  Call arguments:
    inputs: Input tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the
      wrapped layer (only if the layer supports this argument).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked. This argument is passed to the
      wrapped layer (only if the layer supports this argument).

  Raises:
    ValueError: If not initialized with a `Layer` instance.
  """

  def __init__(self, layer, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `TimeDistributed` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    super(TimeDistributed, self).__init__(layer, **kwargs)
    self.supports_masking = True

    # It is safe to use the fast, reshape-based approach with all of our
    # built-in Layers.
    self._always_use_reshape = (
        layer_utils.is_builtin_layer(layer) and
        not getattr(layer, 'stateful', False))

  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
    """Finds non-specific dimensions in the static shapes.

    The static shapes are replaced with the corresponding dynamic shapes of
    the tensor.

    Arguments:
      init_tuple: a tuple, the first part of the output shape
      tensor: the tensor from which to get the (static and dynamic) shapes
        as the last part of the output shape
      start_idx: int, which indicates the first dimension to take from
        the static shape of the tensor
      int_shape: an alternative static shape to take as the last part
        of the output shape

    Returns:
      The new int_shape with the first part from init_tuple
      and the last part from either `int_shape` (if provided)
      or `tensor.shape`, where every `None` is replaced by
      the corresponding dimension from `tf.shape(tensor)`.
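
      For example (illustrative): with `init_tuple=(-1,)`, `start_idx=2`, and
      a tensor whose static shape is `(None, 10, 16)`, the result is
      `(-1, 16)`; if the last dimension were also `None`, it would be replaced
      by the corresponding entry of `K.shape(tensor)`.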
    """
    # replace all None in int_shape by K.shape
    if int_shape is None:
      int_shape = K.int_shape(tensor)[start_idx:]
    if not any(not s for s in int_shape):
      return init_tuple + tuple(int_shape)
    shape = K.shape(tensor)
    int_shape = list(int_shape)
    for i, s in enumerate(int_shape):
      if not s:
        int_shape[i] = shape[start_idx + i]
    return init_tuple + tuple(int_shape)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if len(input_shape) < 3:
      raise ValueError(
          '`TimeDistributed` Layer should be passed an `input_shape` '
          'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = InputSpec(shape=[None, None] + input_shape[2:])
    child_input_shape = [input_shape[0]] + input_shape[2:]
    if not self.layer.built:
      # The base layer class calls a conversion function on the input shape to
      # convert it to a TensorShape. The conversion function requires a
      # tuple which is why we cast the shape.
      self.layer.build(tuple(child_input_shape))
      self.layer.built = True
    super(TimeDistributed, self).build()
    self.built = True

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    child_input_shape = tensor_shape.TensorShape([input_shape[0]] +
                                                 input_shape[2:])
    child_output_shape = self.layer.compute_output_shape(
        child_input_shape).as_list()
    timesteps = input_shape[1]
    return tensor_shape.TensorShape([child_output_shape[0], timesteps] +
                                    child_output_shape[1:])

  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training

    input_shape = K.int_shape(inputs)
    if input_shape[0] and not self._always_use_reshape:
      # batch size matters, use rnn-based implementation
      def step(x, _):
        output = self.layer.call(x, **kwargs)
        return output, []

      _, outputs, _ = K.rnn(
          step,
          inputs,
          initial_states=[],
          input_length=input_shape[1],
          unroll=False)
      y = outputs
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with reshape-based implementation for performance.
      input_length = input_shape[1]
      if not input_length:
        input_length = array_ops.shape(inputs)[1]
      inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
      # Shape: (num_samples * timesteps, ...). And track the
      # transformation in self._input_map.
      input_uid = generic_utils.object_list_uid(inputs)
      inputs = array_ops.reshape(inputs, inner_input_shape)
      self._input_map[input_uid] = inputs
      # (num_samples * timesteps, ...)
      if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
        inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
        kwargs['mask'] = K.reshape(mask, inner_mask_shape)
      y = self.layer.call(inputs, **kwargs)
      # Shape: (num_samples, timesteps, ...)
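      # `output_shape` below mixes static dimensions with dynamic ones (via
      # `_get_shape_tuple`), so the reshape works even when the timesteps
      # dimension is only known at runtime.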
      output_shape = self.compute_output_shape(input_shape).as_list()
      output_shape = self._get_shape_tuple(
          (-1, input_length), y, 1, output_shape[2:])
      y = array_ops.reshape(y, output_shape)

    # Apply activity regularizer if any:
    if (hasattr(self.layer, 'activity_regularizer') and
        self.layer.activity_regularizer is not None):
      regularization_loss = self.layer.activity_regularizer(y)
      self.add_loss(regularization_loss, inputs)
    return y

  def compute_mask(self, inputs, mask=None):
    """Computes an output mask tensor for the TimeDistributed layer.

    This is based on the inputs, mask, and the inner layer.
    If batch size is specified:
      Simply return the input `mask`. (An rnn-based implementation with
      more than one rnn input is required but not supported in tf.keras yet.)
    Otherwise we call `compute_mask` of the inner layer at each time step.
      If the output mask at each time step is not `None`:
        (E.g., inner layer is Masking or RNN)
        Concatenate all of them and return the concatenation.
      If the output mask at each time step is `None` and the input mask is
      not `None`: (E.g., inner layer is Dense)
        Reduce the input_mask to 2 dimensions and return it.
      Otherwise (both the output mask and the input mask are `None`):
        (E.g., `mask` is not used at all)
        Return `None`.

    Arguments:
      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
        input to TimeDistributed. If static shape information is available for
        "batch size", `mask` is returned unmodified.
      mask: Either None (indicating no masking) or a Tensor indicating the
        input mask for TimeDistributed. The shape can be static or dynamic.

    Returns:
      Either None (no masking), or a [batch size, timesteps, ...] Tensor with
      an output mask for the TimeDistributed layer, with the shape beyond the
      second dimension being the value of the input mask shape (if the
      computed output mask is None), with the shape beyond the first dimension
      being the value of the mask shape (if the mask is not None), or with the
      shape beyond the first dimension being the value of the computed output
      shape.
    """
    # Cases that need to call layer.compute_mask even when input_mask is None:
    # Masking layer and Embedding layer with mask_zero.
    input_shape = K.int_shape(inputs)
    if input_shape[0]:
      # batch size matters, we currently do not handle mask explicitly
      return mask
    inner_mask = mask
    if inner_mask is not None:
      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
      inner_mask = K.reshape(inner_mask, inner_mask_shape)
    input_uid = generic_utils.object_list_uid(inputs)
    inner_inputs = self._input_map.get(input_uid, inputs)
    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
    if output_mask is None:
      if mask is None:
        return None
      # input_mask is not None, and output_mask is None:
      # we should return a not-None mask
      output_mask = mask
      for _ in range(2, len(K.int_shape(mask))):
        output_mask = K.any(output_mask, axis=-1)
    else:
      # output_mask is not None. We need to reshape it.
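      # The inner layer produced its mask on the flattened
      # (num_samples * timesteps, ...) inputs; reshape it back to
      # (num_samples, timesteps, ...) to match the wrapper's output.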
      input_length = input_shape[1]
      if not input_length:
        input_length = K.shape(inputs)[1]
      output_mask_int_shape = K.int_shape(output_mask)
      if output_mask_int_shape is None:
        # if the output_mask does not have a static shape,
        # its shape must be the same as mask's
        if mask is not None:
          output_mask_int_shape = K.int_shape(mask)
        else:
          output_mask_int_shape = self.compute_output_shape(
              input_shape).as_list()[:-1]
      output_mask_shape = self._get_shape_tuple(
          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
      output_mask = K.reshape(output_mask, output_mask_shape)
    return output_mask


@keras_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
  """Bidirectional wrapper for RNNs.

  Arguments:
    layer: `RNN` instance, e.g. `LSTM` or `GRU`.
    merge_mode: Mode by which outputs of the forward and backward RNNs will be
      combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
      outputs will not be combined, they will be returned as a list.

  Call arguments:
    The call arguments for this layer are the same as those of the wrapped RNN
    layer.

  Raises:
    ValueError: If not initialized with a `Layer` instance, or in case of an
      invalid `merge_mode` argument.

  Examples:

  ```python
  model = Sequential()
  model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
  model.add(Bidirectional(LSTM(10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
  ```
  """

  def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Bidirectional` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
      raise ValueError('Invalid merge mode. '
                       'Merge mode should be one of '
                       '{"sum", "mul", "ave", "concat", None}')
    if getattr(layer, 'zero_output_for_mask', None) is not None:
      # Force the zero_output_for_mask to be True if returning sequences.
      layer.zero_output_for_mask = layer.return_sequences

    self.forward_layer = copy.copy(layer)
    config = layer.get_config()
    config['go_backwards'] = not config['go_backwards']
    self.backward_layer = layer.__class__.from_config(config)
    self.forward_layer._name = 'forward_' + self.forward_layer.name
    self.backward_layer._name = 'backward_' + self.backward_layer.name
    self.merge_mode = merge_mode
    if weights:
      nw = len(weights)
      self.forward_layer.initial_weights = weights[:nw // 2]
      self.backward_layer.initial_weights = weights[nw // 2:]
    self.stateful = layer.stateful
    self.return_sequences = layer.return_sequences
    self.return_state = layer.return_state
    self.supports_masking = True
    self._trainable = True
    self._num_constants = None
    # We don't want to track `layer` since we're already tracking the two
    # copies of it we actually run.
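    # Setting `_setattr_tracking` to False disables automatic dependency
    # tracking in `__setattr__` while the parent constructor assigns
    # `self.layer`; it is restored immediately afterwards.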
    self._setattr_tracking = False
    super(Bidirectional, self).__init__(layer, **kwargs)
    self._setattr_tracking = True
    self.input_spec = layer.input_spec

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    output_shape = tuple(self.forward_layer.compute_output_shape(
        input_shape).as_list())
    if self.return_state:
      state_shape = output_shape[1:]
      output_shape = output_shape[0]

    if self.merge_mode == 'concat':
      output_shape = list(output_shape)
      output_shape[-1] *= 2
      output_shape = tuple(output_shape)
    elif self.merge_mode is None:
      output_shape = [output_shape, copy.copy(output_shape)]

    if self.return_state:
      if self.merge_mode is None:
        return output_shape + state_shape + copy.copy(state_shape)
      return [output_shape] + state_shape + copy.copy(state_shape)
    return output_shape

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if isinstance(inputs, list):
      if len(inputs) > 1:
        initial_state = inputs[1:]
      inputs = inputs[0]

    if initial_state is None and constants is None:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

    # Applies the same workaround as in `RNN.__call__`
    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      # Check if `initial_state` can be split into halves.
      num_states = len(initial_state)
      if num_states % 2 > 0:
        raise ValueError(
            'When passing `initial_state` to a Bidirectional RNN, '
            'the state should be a list containing the states of '
            'the underlying RNNs. '
            'Found: ' + str(initial_state))

      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      state_specs = [InputSpec(shape=K.int_shape(state))
                     for state in initial_state]
      self.forward_layer.state_spec = state_specs[:num_states // 2]
      self.backward_layer.state_spec = state_specs[num_states // 2:]
      additional_specs += state_specs
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      constants_spec = [InputSpec(shape=K.int_shape(constant))
                        for constant in constants]
      self.forward_layer.constants_spec = constants_spec
      self.backward_layer.constants_spec = constants_spec
      additional_specs += constants_spec

      self._num_constants = len(constants)
      self.forward_layer._num_constants = self._num_constants
      self.backward_layer._num_constants = self._num_constants

    is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
    for tensor in additional_inputs:
      if K.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state of a Bidirectional'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state
      full_input = [inputs] + additional_inputs
      # The original input_spec is None since there could be a nested tensor
      # input. Update the input_spec to match the inputs.
      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
                        ] + additional_specs

      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(Bidirectional, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           training=None,
           mask=None,
           initial_state=None,
           constants=None):
    """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    if generic_utils.has_arg(self.layer.call, 'mask'):
      kwargs['mask'] = mask
    if generic_utils.has_arg(self.layer.call, 'constants'):
      kwargs['constants'] = constants

    if initial_state is not None and generic_utils.has_arg(
        self.layer.call, 'initial_state'):
      forward_inputs = [inputs[0]]
      backward_inputs = [inputs[0]]
      pivot = len(initial_state) // 2 + 1
      # add forward initial state
      forward_state = inputs[1:pivot]
      forward_inputs += forward_state
      if self._num_constants is None:
        # add backward initial state
        backward_state = inputs[pivot:]
        backward_inputs += backward_state
      else:
        # add backward initial state
        backward_state = inputs[pivot:-self._num_constants]
        backward_inputs += backward_state
        # add constants for forward and backward layers
        forward_inputs += inputs[-self._num_constants:]
        backward_inputs += inputs[-self._num_constants:]
      y = self.forward_layer.call(forward_inputs,
                                  initial_state=forward_state, **kwargs)
      y_rev = self.backward_layer.call(backward_inputs,
                                       initial_state=backward_state, **kwargs)
    else:
      y = self.forward_layer.call(inputs, **kwargs)
      y_rev = self.backward_layer.call(inputs, **kwargs)

    if self.return_state:
      states = y[1:] + y_rev[1:]
      y = y[0]
      y_rev = y_rev[0]

    if self.return_sequences:
      y_rev = K.reverse(y_rev, 1)
    if self.merge_mode == 'concat':
      output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]
    else:
      raise ValueError(
          'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))

    if self.return_state:
      if self.merge_mode is None:
        return output + states
      return [output] + states
    return output

  def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()

  def build(self, input_shape):
    with K.name_scope(self.forward_layer.name):
      self.forward_layer.build(input_shape)
    with K.name_scope(self.backward_layer.name):
      self.backward_layer.build(input_shape)
    self.built = True

  def compute_mask(self, inputs, mask):
    if isinstance(mask, list):
      mask = mask[0]
    if self.return_sequences:
      if not self.merge_mode:
        output_mask = [mask, mask]
      else:
        output_mask = mask
    else:
      output_mask = [None, None] if not self.merge_mode else None

    if self.return_state:
      states = self.forward_layer.states
      state_mask = [None for _ in states]
      if isinstance(output_mask, list):
        return output_mask + state_mask * 2
      return [output_mask] + state_mask * 2
    return output_mask

  @property
  def constraints(self):
    constraints = {}
    if hasattr(self.forward_layer, 'constraints'):
      constraints.update(self.forward_layer.constraints)
      constraints.update(self.backward_layer.constraints)
    return constraints

  def get_config(self):
    config = {'merge_mode': self.merge_mode}
    if self._num_constants is not None:
      config['num_constants'] = self._num_constants
    base_config = super(Bidirectional, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    num_constants = config.pop('num_constants', None)
    layer = super(Bidirectional, cls).from_config(config,
                                                  custom_objects=custom_objects)
    layer._num_constants = num_constants
    return layer