# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Wrapper layers: layers that augment the functionality of another layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Wrapper')
class Wrapper(Layer):
  """Abstract wrapper base class.

  Wrappers take another layer and augment it in various ways.
  Do not use this class as a layer; it is only an abstract base class.
  Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

  Arguments:
    layer: The layer to be wrapped.
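
  Example:

  A minimal sketch (the `MyWrapper` name is hypothetical) of a custom wrapper
  that simply delegates to the wrapped layer; built-in wrappers such as
  `TimeDistributed` and `Bidirectional` add more behavior on top of this
  pattern:

  ```python
  class MyWrapper(Wrapper):

    def build(self, input_shape):
      # Build the wrapped layer first, then mark the wrapper itself as built.
      if not self.layer.built:
        self.layer.build(input_shape)
      super(MyWrapper, self).build()

    def call(self, inputs):
      return self.layer.call(inputs)

  y = MyWrapper(Dense(8))(x)  # Behaves like a plain Dense(8) layer.
  ```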
  """

  def __init__(self, layer, **kwargs):
    assert isinstance(layer, Layer)
    self.layer = layer
    # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
    # the inner layer has update ops that depend on its inputs (as opposed
    # to the inputs to the Wrapper layer).
    self._input_map = {}
    super(Wrapper, self).__init__(**kwargs)

  def build(self, input_shape=None):
    self.built = True

  @property
  def activity_regularizer(self):
    if hasattr(self.layer, 'activity_regularizer'):
      return self.layer.activity_regularizer
    else:
      return None

  def get_config(self):
    config = {
        'layer': {
            'class_name': self.layer.__class__.__name__,
            'config': self.layer.get_config()
        }
    }
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    layer = deserialize_layer(
        config.pop('layer'), custom_objects=custom_objects)
    return cls(layer, **config)


@keras_export('keras.layers.TimeDistributed')
class TimeDistributed(Wrapper):
  """This wrapper allows applying a layer to every temporal slice of an input.

  The input should be at least 3D, and the dimension at index one
  will be considered to be the temporal dimension.

  Consider a batch of 32 samples,
  where each sample is a sequence of 10 vectors of 16 dimensions.
  The batch input shape of the layer is then `(32, 10, 16)`,
  and the `input_shape`, not including the samples dimension, is `(10, 16)`.

  You can then use `TimeDistributed` to apply a `Dense` layer
  to each of the 10 timesteps, independently:

  ```python
  # as the first layer in a model
  model = Sequential()
  model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))
  # now model.output_shape == (None, 10, 8)
  ```

  The output will then have shape `(32, 10, 8)`.

  In subsequent layers, there is no need for the `input_shape`:

  ```python
  model.add(TimeDistributed(Dense(32)))
  # now model.output_shape == (None, 10, 32)
  ```

  The output will then have shape `(32, 10, 32)`.

  `TimeDistributed` can be used with arbitrary layers, not just `Dense`,
  for instance with a `Conv2D` layer:

  ```python
  model = Sequential()
  model.add(TimeDistributed(Conv2D(64, (3, 3)),
                            input_shape=(10, 299, 299, 3)))
  ```
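
  A common pattern (shown here as a sketch; shapes follow the `Dense` example
  above) is to apply a `Dense` layer to every timestep produced by an RNN
  that returns its full sequence:

  ```python
  model = Sequential()
  # LSTM with return_sequences=True outputs shape (None, 10, 32).
  model.add(LSTM(32, return_sequences=True, input_shape=(10, 16)))
  # Dense is applied to each of the 10 timesteps independently.
  model.add(TimeDistributed(Dense(8)))
  # now model.output_shape == (None, 10, 8)
  ```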

  Arguments:
    layer: a layer instance.

  Call arguments:
    inputs: Input tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the
      wrapped layer (only if the layer supports this argument).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked. This argument is passed to the
      wrapped layer (only if the layer supports this argument).

  Raises:
    ValueError: If not initialized with a `Layer` instance.
  """

  def __init__(self, layer, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `TimeDistributed` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    super(TimeDistributed, self).__init__(layer, **kwargs)
    self.supports_masking = True

    # It is safe to use the fast, reshape-based approach with all of our
    # built-in Layers.
    self._always_use_reshape = (
        layer_utils.is_builtin_layer(layer) and
        not getattr(layer, 'stateful', False))

  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
    """Finds non-specific (None) dimensions in the static shape.

    The non-specific static dimensions are replaced with the corresponding
    dynamic dimensions of the tensor.

    Arguments:
      init_tuple: a tuple, the first part of the output shape
      tensor: the tensor from which to get the (static and dynamic) shapes
        as the last part of the output shape
      start_idx: int, which indicates the first dimension to take from
        the static shape of the tensor
      int_shape: an alternative static shape to take as the last part
        of the output shape

    Returns:
      The new shape tuple with the first part from `init_tuple`
      and the last part from either `int_shape` (if provided)
      or `tensor.shape`, where every `None` is replaced by
      the corresponding dimension from `tf.shape(tensor)`.
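
    Example (a sketch, assuming `x` is a tensor with static shape
    `(None, None, 16)`):

    ```python
    self._get_shape_tuple((-1,), x, 1)
    # -> (-1, <dynamic timesteps>, 16): the unknown timestep dimension is
    #    replaced by the corresponding entry of `tf.shape(x)`, while known
    #    static dimensions are kept as Python ints.
    ```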
    """
    # replace all None in int_shape by K.shape
    if int_shape is None:
      int_shape = K.int_shape(tensor)[start_idx:]
    if not any(not s for s in int_shape):
      return init_tuple + tuple(int_shape)
    shape = K.shape(tensor)
    int_shape = list(int_shape)
    for i, s in enumerate(int_shape):
      if not s:
        int_shape[i] = shape[start_idx + i]
    return init_tuple + tuple(int_shape)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if len(input_shape) < 3:
      raise ValueError(
          '`TimeDistributed` Layer should be passed an `input_shape` '
          'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = InputSpec(shape=[None, None] + input_shape[2:])
    child_input_shape = [input_shape[0]] + input_shape[2:]
    if not self.layer.built:
      # The base layer class calls a conversion function on the input shape to
      # convert it to a TensorShape. The conversion function requires a
      # tuple which is why we cast the shape.
      self.layer.build(tuple(child_input_shape))
      self.layer.built = True
    super(TimeDistributed, self).build()
    self.built = True

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    child_input_shape = tensor_shape.TensorShape([input_shape[0]] +
                                                 input_shape[2:])
    child_output_shape = self.layer.compute_output_shape(
        child_input_shape).as_list()
    timesteps = input_shape[1]
    return tensor_shape.TensorShape([child_output_shape[0], timesteps] +
                                    child_output_shape[1:])

  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training

    input_shape = K.int_shape(inputs)
    if input_shape[0] and not self._always_use_reshape:
      # batch size matters, use rnn-based implementation
      def step(x, _):
        output = self.layer.call(x, **kwargs)
        return output, []

      _, outputs, _ = K.rnn(
          step,
          inputs,
          initial_states=[],
          input_length=input_shape[1],
          unroll=False)
      y = outputs
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with the reshape-based implementation for performance.
      input_length = input_shape[1]
      if not input_length:
        input_length = array_ops.shape(inputs)[1]
      inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
      # Shape: (num_samples * timesteps, ...), and track the
      # transformation in self._input_map.
      input_uid = generic_utils.object_list_uid(inputs)
      inputs = array_ops.reshape(inputs, inner_input_shape)
      self._input_map[input_uid] = inputs
      # (num_samples * timesteps, ...)
      if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
        inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
        kwargs['mask'] = K.reshape(mask, inner_mask_shape)
      y = self.layer.call(inputs, **kwargs)
      # Shape: (num_samples, timesteps, ...)
      output_shape = self.compute_output_shape(input_shape).as_list()
      output_shape = self._get_shape_tuple(
          (-1, input_length), y, 1, output_shape[2:])
      y = array_ops.reshape(y, output_shape)

    # Apply activity regularizer if any:
    if (hasattr(self.layer, 'activity_regularizer') and
        self.layer.activity_regularizer is not None):
      regularization_loss = self.layer.activity_regularizer(y)
      self.add_loss(regularization_loss, inputs)
    return y

  def compute_mask(self, inputs, mask=None):
    """Computes an output mask tensor for the TimeDistributed layer.

    The mask is computed based on the inputs, the input mask, and the inner
    layer.
    If the batch size is specified:
    Simply return the input `mask`. (An rnn-based implementation with
    more than one RNN input is required but is not supported in tf.keras yet.)
    Otherwise we call `compute_mask` of the inner layer at each time step.
    If the output mask at each time step is not `None`
    (e.g., the inner layer is Masking or an RNN):
    Concatenate all of them and return the concatenation.
    If the output mask at each time step is `None` and the input mask is not
    `None` (e.g., the inner layer is Dense):
    Reduce the input mask to 2 dimensions and return it.
    Otherwise (both the output mask and the input mask are `None`,
    e.g., `mask` is not used at all):
    Return `None`.

    Arguments:
      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
        input to TimeDistributed. If static shape information is available for
        "batch size", `mask` is returned unmodified.
      mask: Either None (indicating no masking) or a Tensor indicating the
        input mask for TimeDistributed. The shape can be static or dynamic.

    Returns:
      Either None (no masking), or a [batch size, timesteps, ...] Tensor with
      an output mask for the TimeDistributed layer, with the shape beyond the
      second dimension being the value of the input mask shape (if the
      computed output mask is None), an output mask with the shape beyond the
      first dimension being the value of the mask shape (if mask is not None),
      or an output mask with the shape beyond the first dimension being the
      value of the computed output shape.
    """
    # Cases that need to call the inner layer's compute_mask when the input
    # mask is None: a Masking layer, or an Embedding layer with mask_zero.
    input_shape = K.int_shape(inputs)
    if input_shape[0]:
      # batch size matters, we currently do not handle mask explicitly
      return mask
    inner_mask = mask
    if inner_mask is not None:
      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
      inner_mask = K.reshape(inner_mask, inner_mask_shape)
    input_uid = generic_utils.object_list_uid(inputs)
    inner_inputs = self._input_map.get(input_uid, inputs)
    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
    if output_mask is None:
      if mask is None:
        return None
      # input_mask is not None, and output_mask is None:
      # we should return a not-None mask
      output_mask = mask
      for _ in range(2, len(K.int_shape(mask))):
        output_mask = K.any(output_mask, axis=-1)
    else:
      # output_mask is not None. We need to reshape it
      input_length = input_shape[1]
      if not input_length:
        input_length = K.shape(inputs)[1]
      output_mask_int_shape = K.int_shape(output_mask)
      if output_mask_int_shape is None:
        # if the output_mask does not have a static shape,
        # its shape must be the same as mask's
        if mask is not None:
          output_mask_int_shape = K.int_shape(mask)
        else:
          output_mask_int_shape = K.compute_output_shape(input_shape)[:-1]
      output_mask_shape = self._get_shape_tuple(
          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
      output_mask = K.reshape(output_mask, output_mask_shape)
    return output_mask


@keras_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
  """Bidirectional wrapper for RNNs.

  Arguments:
    layer: `keras.layers.RNN` instance, such as `LSTM` or `GRU`.
    merge_mode: Mode by which outputs of the
      forward and backward RNNs will be combined.
      One of {'sum', 'mul', 'concat', 'ave', None}.
      If None, the outputs will not be combined;
      they will be returned as a list.

  Call arguments:
    The call arguments for this layer are the same as those of the wrapped RNN
      layer.

  Raises:
    ValueError: If not initialized with a `Layer` instance, or
      if the `merge_mode` argument is invalid.

  Examples:

  ```python
  model = Sequential()
  model.add(Bidirectional(LSTM(10, return_sequences=True),
                          input_shape=(5, 10)))
  model.add(Bidirectional(LSTM(10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
  ```
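
  A sketch of passing an `initial_state` through the functional API (the
  tensor names here are illustrative). The state list must contain the
  states of the forward RNN followed by those of the backward RNN, so an
  `LSTM` needs 2 + 2 state tensors:

  ```python
  inputs = Input(shape=(5, 10))
  forward_h = Input(shape=(16,))
  forward_c = Input(shape=(16,))
  backward_h = Input(shape=(16,))
  backward_c = Input(shape=(16,))
  outputs = Bidirectional(LSTM(16))(
      inputs,
      initial_state=[forward_h, forward_c, backward_h, backward_c])
  ```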
  """

  def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Bidirectional` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
      raise ValueError('Invalid merge mode. '
                       'Merge mode should be one of '
                       '{"sum", "mul", "ave", "concat", None}')
    if getattr(layer, 'zero_output_for_mask', None) is not None:
      # Force the zero_output_for_mask to be True if returning sequences.
      layer.zero_output_for_mask = layer.return_sequences

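    # The forward layer is a copy of `layer`, while the backward layer is
    # rebuilt from the same config with `go_backwards` flipped, so the two
    # copies process the sequence in opposite directions.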
    self.forward_layer = copy.copy(layer)
    config = layer.get_config()
    config['go_backwards'] = not config['go_backwards']
    self.backward_layer = layer.__class__.from_config(config)
    self.forward_layer._name = 'forward_' + self.forward_layer.name
    self.backward_layer._name = 'backward_' + self.backward_layer.name
    self.merge_mode = merge_mode
    if weights:
      nw = len(weights)
      self.forward_layer.initial_weights = weights[:nw // 2]
      self.backward_layer.initial_weights = weights[nw // 2:]
    self.stateful = layer.stateful
    self.return_sequences = layer.return_sequences
    self.return_state = layer.return_state
    self.supports_masking = True
    self._trainable = True
    self._num_constants = None
    # We don't want to track `layer` since we're already tracking the two copies
    # of it we actually run.
    self._setattr_tracking = False
    super(Bidirectional, self).__init__(layer, **kwargs)
    self._setattr_tracking = True
    self.input_spec = layer.input_spec

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    output_shape = tuple(self.forward_layer.compute_output_shape(
        input_shape).as_list())
    if self.return_state:
      state_shape = output_shape[1:]
      output_shape = output_shape[0]

    if self.merge_mode == 'concat':
      output_shape = list(output_shape)
      output_shape[-1] *= 2
      output_shape = tuple(output_shape)
    elif self.merge_mode is None:
      output_shape = [output_shape, copy.copy(output_shape)]

    if self.return_state:
      if self.merge_mode is None:
        return output_shape + state_shape + copy.copy(state_shape)
      return [output_shape] + state_shape + copy.copy(state_shape)
    return output_shape

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if isinstance(inputs, list):
      if len(inputs) > 1:
        initial_state = inputs[1:]
      inputs = inputs[0]

    if initial_state is None and constants is None:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

    # Applies the same workaround as in `RNN.__call__`
    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      # Check if `initial_state` can be split in half
      num_states = len(initial_state)
      if num_states % 2 > 0:
        raise ValueError(
            'When passing `initial_state` to a Bidirectional RNN, '
            'the state should be a list containing the states of '
            'the underlying RNNs. '
            'Found: ' + str(initial_state))

      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      state_specs = [InputSpec(shape=K.int_shape(state))
                     for state in initial_state]
      self.forward_layer.state_spec = state_specs[:num_states // 2]
      self.backward_layer.state_spec = state_specs[num_states // 2:]
      additional_specs += state_specs
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      constants_spec = [InputSpec(shape=K.int_shape(constant))
                        for constant in constants]
      self.forward_layer.constants_spec = constants_spec
      self.backward_layer.constants_spec = constants_spec
      additional_specs += constants_spec

      self._num_constants = len(constants)
      self.forward_layer._num_constants = self._num_constants
      self.backward_layer._num_constants = self._num_constants

    is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
    for tensor in additional_inputs:
      if K.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state of a Bidirectional'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state
      full_input = [inputs] + additional_inputs
      # The original input_spec is None since there could be a nested tensor
      # input. Update the input_spec to match the inputs.
      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
                        ] + additional_specs

      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(Bidirectional, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           training=None,
           mask=None,
           initial_state=None,
           constants=None):
    """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    if generic_utils.has_arg(self.layer.call, 'mask'):
      kwargs['mask'] = mask
    if generic_utils.has_arg(self.layer.call, 'constants'):
      kwargs['constants'] = constants

    if initial_state is not None and generic_utils.has_arg(
        self.layer.call, 'initial_state'):
      forward_inputs = [inputs[0]]
      backward_inputs = [inputs[0]]
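      # `inputs` is [main input] + initial_state (+ constants, if any);
      # `pivot` below marks the end of the forward half of the state tensors
      # (index 0 is the main input, hence the `+ 1`).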
      pivot = len(initial_state) // 2 + 1
      # add forward initial state
      forward_state = inputs[1:pivot]
      forward_inputs += forward_state
      if self._num_constants is None:
        # add backward initial state
        backward_state = inputs[pivot:]
        backward_inputs += backward_state
      else:
        # add backward initial state
        backward_state = inputs[pivot:-self._num_constants]
        backward_inputs += backward_state
        # add constants for forward and backward layers
        forward_inputs += inputs[-self._num_constants:]
        backward_inputs += inputs[-self._num_constants:]
      y = self.forward_layer.call(forward_inputs,
                                  initial_state=forward_state, **kwargs)
      y_rev = self.backward_layer.call(backward_inputs,
                                       initial_state=backward_state, **kwargs)
    else:
      y = self.forward_layer.call(inputs, **kwargs)
      y_rev = self.backward_layer.call(inputs, **kwargs)

    if self.return_state:
      states = y[1:] + y_rev[1:]
      y = y[0]
      y_rev = y_rev[0]

    if self.return_sequences:
      y_rev = K.reverse(y_rev, 1)
    if self.merge_mode == 'concat':
      output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]
    else:
      raise ValueError(
          'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))

    if self.return_state:
      if self.merge_mode is None:
        return output + states
      return [output] + states
    return output

  def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()

  def build(self, input_shape):
    with K.name_scope(self.forward_layer.name):
      self.forward_layer.build(input_shape)
    with K.name_scope(self.backward_layer.name):
      self.backward_layer.build(input_shape)
    self.built = True

  def compute_mask(self, inputs, mask):
    if isinstance(mask, list):
      mask = mask[0]
    if self.return_sequences:
      if not self.merge_mode:
        output_mask = [mask, mask]
      else:
        output_mask = mask
    else:
      output_mask = [None, None] if not self.merge_mode else None

    if self.return_state:
      states = self.forward_layer.states
      state_mask = [None for _ in states]
      if isinstance(output_mask, list):
        return output_mask + state_mask * 2
      return [output_mask] + state_mask * 2
    return output_mask

  @property
  def constraints(self):
    constraints = {}
    if hasattr(self.forward_layer, 'constraints'):
      constraints.update(self.forward_layer.constraints)
      constraints.update(self.backward_layer.constraints)
    return constraints

  def get_config(self):
    config = {'merge_mode': self.merge_mode}
    if self._num_constants is not None:
      config['num_constants'] = self._num_constants
    base_config = super(Bidirectional, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    num_constants = config.pop('num_constants', None)
    layer = super(Bidirectional, cls).from_config(config,
                                                  custom_objects=custom_objects)
    layer._num_constants = num_constants
    return layer
