# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layer wrappers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import numpy as np

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import object_identity
from tensorflow.python.training.tracking import util as trackable_util


class _RNNCellWithConstants(keras.layers.Layer):
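  """Minimal RNN cell that takes an extra `constants` input.

  Used to exercise `constants` handling in the `RNN` and `Bidirectional`
  wrappers below.
  """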

  def __init__(self, units, **kwargs):
    self.units = units
    self.state_size = units
    super(_RNNCellWithConstants, self).__init__(**kwargs)

  def build(self, input_shape):
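    # `input_shape` is a list of two shapes: the inputs and the constants.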
    [input_shape, constant_shape] = input_shape

    self.input_kernel = self.add_weight(
        shape=(input_shape[-1], self.units),
        initializer='uniform',
        name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform',
        name='recurrent_kernel')
    self.constant_kernel = self.add_weight(
        shape=(constant_shape[-1], self.units),
        initializer='uniform',
        name='constant_kernel')
    self.built = True

  def call(self, inputs, states, constants):
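    # Project the input, the previous state, and the constant, then sum them.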
    [prev_output] = states
    [constant] = constants
    h_input = keras.backend.dot(inputs, self.input_kernel)
    h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
    h_const = keras.backend.dot(constant, self.constant_kernel)
    output = h_input + h_state + h_const
    return output, [output]

  def get_config(self):
    config = {'units': self.units}
    base_config = super(_RNNCellWithConstants, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class TimeDistributedTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes
  def test_timedistributed_dense(self):
    model = keras.models.Sequential()
    model.add(
        keras.layers.TimeDistributed(
            keras.layers.Dense(2), input_shape=(3, 4)))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(
        np.random.random((10, 3, 4)),
        np.random.random((10, 3, 2)),
        epochs=1,
        batch_size=10)

    # test config
    model.get_config()

    # check whether the model variables are present in the
    # trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
      self.assertIn(v, checkpointed_objects)

  def test_timedistributed_static_batch_size(self):
    model = keras.models.Sequential()
    model.add(
        keras.layers.TimeDistributed(
            keras.layers.Dense(2), input_shape=(3, 4), batch_size=10))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(
        np.random.random((10, 3, 4)),
        np.random.random((10, 3, 2)),
        epochs=1,
        batch_size=10)

  def test_timedistributed_invalid_init(self):
    x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
    with self.assertRaisesRegexp(
        ValueError,
        'Please initialize `TimeDistributed` layer with a `Layer` instance.'):
      keras.layers.TimeDistributed(x)

  def test_timedistributed_conv2d(self):
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.TimeDistributed(
              keras.layers.Conv2D(5, (2, 2), padding='same'),
              input_shape=(2, 4, 4, 3)))
      model.add(keras.layers.Activation('relu'))
      model.compile(optimizer='rmsprop', loss='mse')
      model.train_on_batch(
          np.random.random((1, 2, 4, 4, 3)), np.random.random((1, 2, 4, 4, 5)))

      model = keras.models.model_from_json(model.to_json())
      model.summary()

  def test_timedistributed_stacked(self):
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.TimeDistributed(
              keras.layers.Dense(2), input_shape=(3, 4)))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
      model.add(keras.layers.Activation('relu'))
      model.compile(optimizer='rmsprop', loss='mse')

      model.fit(
          np.random.random((10, 3, 4)),
          np.random.random((10, 3, 3)),
          epochs=1,
          batch_size=10)

  def test_regularizers(self):
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.TimeDistributed(
              keras.layers.Dense(2, kernel_regularizer='l1'),
              input_shape=(3, 4)))
      model.add(keras.layers.Activation('relu'))
      model.compile(optimizer='rmsprop', loss='mse')
      self.assertEqual(len(model.losses), 1)

  def test_TimeDistributed_batchnorm(self):
    with self.cached_session():
      # test that wrapped BN updates still work.
      model = keras.models.Sequential()
      model.add(keras.layers.TimeDistributed(
          keras.layers.BatchNormalization(center=True, scale=True),
          name='bn',
          input_shape=(10, 2)))
      model.compile(optimizer='rmsprop', loss='mse')
      # Assert that mean and variance are 0 and 1.
      td = model.layers[0]
      self.assertAllClose(td.get_weights()[2], np.array([0, 0]))
      assert np.array_equal(td.get_weights()[3], np.array([1, 1]))
      # Train
      model.train_on_batch(np.random.normal(loc=2, scale=2, size=(1, 10, 2)),
                           np.broadcast_to(np.array([0, 1]), (1, 10, 2)))
      # Assert that mean and variance changed.
      assert not np.array_equal(td.get_weights()[2], np.array([0, 0]))
      assert not np.array_equal(td.get_weights()[3], np.array([1, 1]))
      # Verify input_map has one mapping from inputs to reshaped inputs.
      self.assertEqual(len(td._input_map.keys()), 1)

  def test_TimeDistributed_trainable(self):
    # test layers that need learning_phase to be set
    x = keras.layers.Input(shape=(3, 2))
    layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
    _ = layer(x)
    self.assertEqual(len(layer.updates), 2)
    self.assertEqual(len(layer.trainable_weights), 2)
    layer.trainable = False
    assert not layer.updates
    assert not layer.trainable_weights
    layer.trainable = True
    assert len(layer.updates) == 2
    assert len(layer.trainable_weights) == 2

  def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
    with self.cached_session():
      # test with unspecified shape and Embeddings with mask_zero
      model = keras.models.Sequential()
      model.add(keras.layers.TimeDistributed(
          keras.layers.Embedding(5, 6, mask_zero=True),
          input_shape=(None, None)))  # N by t_1 by t_2 by 6
      model.add(keras.layers.TimeDistributed(
          keras.layers.SimpleRNN(7, return_sequences=True)))
      model.add(keras.layers.TimeDistributed(
          keras.layers.SimpleRNN(8, return_sequences=False)))
      model.add(keras.layers.SimpleRNN(1, return_sequences=False))
      model.compile(optimizer='rmsprop', loss='mse')
      model_input = np.random.randint(low=1, high=5, size=(10, 3, 4),
                                      dtype='int32')
      for i in range(4):
        model_input[i, i:, i:] = 0
      model.fit(model_input,
                np.random.random((10, 1)), epochs=1, batch_size=10)
      mask_outputs = [model.layers[0].compute_mask(model.input)]
      for layer in model.layers[1:]:
        mask_outputs.append(layer.compute_mask(layer.input, mask_outputs[-1]))
      func = keras.backend.function([model.input], mask_outputs[:-1])
      mask_outputs_val = func([model_input])
      ref_mask_val_0 = model_input > 0         # embedding layer
      ref_mask_val_1 = ref_mask_val_0          # first RNN layer
      ref_mask_val_2 = np.any(ref_mask_val_1, axis=-1)     # second RNN layer
      ref_mask_val = [ref_mask_val_0, ref_mask_val_1, ref_mask_val_2]
      for i in range(3):
        self.assertAllEqual(mask_outputs_val[i], ref_mask_val[i])
      self.assertIs(mask_outputs[-1], None)  # final layer

  def test_TimeDistributed_with_masking_layer(self):
    with self.cached_session():
      # test with Masking layer
      model = keras.models.Sequential()
      model.add(keras.layers.TimeDistributed(keras.layers.Masking(
          mask_value=0.,), input_shape=(None, 4)))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(5)))
      model.compile(optimizer='rmsprop', loss='mse')
      model_input = np.random.randint(low=1, high=5, size=(10, 3, 4))
      for i in range(4):
        model_input[i, i:, :] = 0.
      model.fit(model_input,
                np.random.random((10, 3, 5)), epochs=1, batch_size=6)
      mask_outputs = [model.layers[0].compute_mask(model.input)]
      mask_outputs += [model.layers[1].compute_mask(model.layers[1].input,
                                                    mask_outputs[-1])]
      func = keras.backend.function([model.input], mask_outputs)
      mask_outputs_val = func([model_input])
      # Both masks should mirror the per-timestep nonzero pattern of the
      # input: Masking computes it, and the Dense wrapper passes it through.
      ref_mask = np.any(model_input != 0, axis=-1)
      self.assertAllEqual(mask_outputs_val[0], ref_mask)
      self.assertAllEqual(mask_outputs_val[1], ref_mask)

  def test_TimeDistributed_with_different_time_shapes(self):
    time_dist = keras.layers.TimeDistributed(keras.layers.Dense(5))
    ph_1 = keras.backend.placeholder(shape=(None, 10, 13))
    out_1 = time_dist(ph_1)
    self.assertEqual(out_1.shape.as_list(), [None, 10, 5])

    ph_2 = keras.backend.placeholder(shape=(None, 1, 13))
    out_2 = time_dist(ph_2)
    self.assertEqual(out_2.shape.as_list(), [None, 1, 5])

    ph_3 = keras.backend.placeholder(shape=(None, 1, 18))
    with self.assertRaisesRegexp(ValueError, 'is incompatible with layer'):
      time_dist(ph_3)

  def test_TimeDistributed_with_invalid_dimensions(self):
    time_dist = keras.layers.TimeDistributed(keras.layers.Dense(5))
    ph = keras.backend.placeholder(shape=(None, 10))
    with self.assertRaisesRegexp(
        ValueError,
        '`TimeDistributed` Layer should be passed an `input_shape `'):
      time_dist(ph)

  @tf_test_util.run_in_graph_and_eager_modes
  def test_TimeDistributed_reshape(self):

    class NoReshapeLayer(keras.layers.Layer):
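      """Pass-through layer; custom layers should not use the reshape path."""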

      def call(self, inputs):
        return inputs

    # Built-in layers that aren't stateful use the reshape implementation.
    td1 = keras.layers.TimeDistributed(keras.layers.Dense(5))
    self.assertTrue(td1._always_use_reshape)

    # Built-in layers that are stateful don't use the reshape implementation.
    td2 = keras.layers.TimeDistributed(
        keras.layers.RNN(keras.layers.SimpleRNNCell(10), stateful=True))
    self.assertFalse(td2._always_use_reshape)

    # Custom layers are not whitelisted for the fast reshape implementation.
    td3 = keras.layers.TimeDistributed(NoReshapeLayer())
    self.assertFalse(td3._always_use_reshape)


class BidirectionalTest(test.TestCase):

  def test_bidirectional(self):
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    with self.cached_session():
      for mode in ['sum', 'concat', 'ave', 'mul']:
        x = np.random.random((samples, timesteps, dim))
        target_dim = 2 * output_dim if mode == 'concat' else output_dim
        y = np.random.random((samples, target_dim))

        # test with Sequential model
        model = keras.models.Sequential()
        model.add(
            keras.layers.Bidirectional(
                rnn(output_dim), merge_mode=mode, input_shape=(timesteps, dim)))
        model.compile(optimizer='rmsprop', loss='mse')
        model.fit(x, y, epochs=1, batch_size=1)

        # check whether the model variables are present in the
        # trackable list of objects
        checkpointed_objects = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for v in model.variables:
          self.assertIn(v, checkpointed_objects)

        # test compute output shape
        ref_shape = model.layers[-1].output.get_shape()
        shape = model.layers[-1].compute_output_shape(
            (None, timesteps, dim))
        self.assertListEqual(shape.as_list(), ref_shape.as_list())

        # test config
        model.get_config()
        model = keras.models.model_from_json(model.to_json())
        model.summary()

  def test_bidirectional_invalid_init(self):
    x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
    with self.assertRaisesRegexp(
        ValueError,
        'Please initialize `Bidirectional` layer with a `Layer` instance.'):
      keras.layers.Bidirectional(x)

  def test_bidirectional_weight_loading(self):
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    with self.cached_session():
      x = np.random.random((samples, timesteps, dim))
      model = keras.models.Sequential()
      model.add(
          keras.layers.Bidirectional(
              rnn(output_dim), input_shape=(timesteps, dim)))
      y_ref = model.predict(x)
      weights = model.layers[-1].get_weights()
      model.layers[-1].set_weights(weights)
      y = model.predict(x)
      self.assertAllClose(y, y_ref)

  def test_bidirectional_stacked(self):
    # test stacked bidirectional layers
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    mode = 'sum'

    with self.cached_session():
      x = np.random.random((samples, timesteps, dim))
      target_dim = 2 * output_dim if mode == 'concat' else output_dim
      y = np.random.random((samples, target_dim))

      model = keras.models.Sequential()
      model.add(
          keras.layers.Bidirectional(
              rnn(output_dim, return_sequences=True),
              merge_mode=mode,
              input_shape=(timesteps, dim)))
      model.add(keras.layers.Bidirectional(rnn(output_dim), merge_mode=mode))
      model.compile(loss='mse', optimizer='sgd')
      model.fit(x, y, epochs=1, batch_size=1)

      # test with functional API
      inputs = keras.layers.Input((timesteps, dim))
      output = keras.layers.Bidirectional(
          rnn(output_dim), merge_mode=mode)(inputs)
      model = keras.models.Model(inputs, output)
      model.compile(loss='mse', optimizer='sgd')
      model.fit(x, y, epochs=1, batch_size=1)

  def test_bidirectional_statefulness(self):
    # Bidirectional and stateful
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    mode = 'sum'

    with self.cached_session():
      x = np.random.random((samples, timesteps, dim))
      target_dim = 2 * output_dim if mode == 'concat' else output_dim
      y = np.random.random((samples, target_dim))

      inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
      output = keras.layers.Bidirectional(
          rnn(output_dim, stateful=True), merge_mode=mode)(inputs)
      model = keras.models.Model(inputs, output)
      model.compile(loss='mse', optimizer='sgd')
      model.fit(x, y, epochs=1, batch_size=1)

  def test_Bidirectional_merged_value(self):
    rnn = keras.layers.LSTM
    samples = 2
    dim = 5
    timesteps = 3
    units = 3
    x = [np.random.rand(samples, timesteps, dim)]

    with self.cached_session():
      for merge_mode in ['sum', 'mul', 'ave', 'concat', None]:
        if merge_mode == 'sum':
          merge_func = lambda y, y_rev: y + y_rev
        elif merge_mode == 'mul':
          merge_func = lambda y, y_rev: y * y_rev
        elif merge_mode == 'ave':
          merge_func = lambda y, y_rev: (y + y_rev) / 2
        elif merge_mode == 'concat':
          merge_func = lambda y, y_rev: np.concatenate((y, y_rev), axis=-1)
        else:
          merge_func = lambda y, y_rev: [y, y_rev]

        # basic case
        inputs = keras.Input((timesteps, dim))
        layer = keras.layers.Bidirectional(
            rnn(units, return_sequences=True), merge_mode=merge_mode)
        f_merged = keras.backend.function([inputs], _to_list(layer(inputs)))
        f_forward = keras.backend.function([inputs],
                                           [layer.forward_layer(inputs)])
        f_backward = keras.backend.function(
            [inputs],
            [keras.backend.reverse(layer.backward_layer(inputs), 1)])

        y_merged = f_merged(x)
        y_expected = _to_list(merge_func(f_forward(x)[0], f_backward(x)[0]))
        assert len(y_merged) == len(y_expected)
        for x1, x2 in zip(y_merged, y_expected):
          self.assertAllClose(x1, x2, atol=1e-5)

        # test return_state
        inputs = keras.Input((timesteps, dim))
        layer = keras.layers.Bidirectional(
            rnn(units, return_state=True), merge_mode=merge_mode)
        f_merged = keras.backend.function([inputs], layer(inputs))
        f_forward = keras.backend.function([inputs],
                                           layer.forward_layer(inputs))
        f_backward = keras.backend.function([inputs],
                                            layer.backward_layer(inputs))
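        # LSTM carries two states (h and c), so the merged call returns the
        # merged outputs followed by n_states state tensors per direction.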
        n_states = len(layer.layer.states)

        y_merged = f_merged(x)
        y_forward = f_forward(x)
        y_backward = f_backward(x)
        y_expected = _to_list(merge_func(y_forward[0], y_backward[0]))
        assert len(y_merged) == len(y_expected) + n_states * 2
        for x1, x2 in zip(y_merged, y_expected):
          self.assertAllClose(x1, x2, atol=1e-5)

        y_merged = y_merged[-n_states * 2:]
        y_forward = y_forward[-n_states:]
        y_backward = y_backward[-n_states:]
        for state_birnn, state_inner in zip(y_merged, y_forward + y_backward):
          self.assertAllClose(state_birnn, state_inner, atol=1e-5)

  def test_Bidirectional_dropout(self):
    rnn = keras.layers.LSTM
    samples = 2
    dim = 5
    timesteps = 3
    units = 3
    merge_mode = 'sum'
    x = [np.random.rand(samples, timesteps, dim)]

    with self.cached_session():
      inputs = keras.Input((timesteps, dim))
      wrapped = keras.layers.Bidirectional(
          rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode)
      outputs = _to_list(wrapped(inputs, training=True))

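      # Rebuild with return_state; dropout must be inactive at inference, so
      # two predict calls on the same input should agree.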
      inputs = keras.Input((timesteps, dim))
      wrapped = keras.layers.Bidirectional(
          rnn(units, dropout=0.2, return_state=True), merge_mode=merge_mode)
      outputs = _to_list(wrapped(inputs))

      model = keras.Model(inputs, outputs)
      y1 = _to_list(model.predict(x))
      y2 = _to_list(model.predict(x))
      for x1, x2 in zip(y1, y2):
        self.assertAllClose(x1, x2, atol=1e-5)

  def test_Bidirectional_state_reuse(self):
    rnn = keras.layers.LSTM
    samples = 2
    dim = 5
    timesteps = 3
    units = 3

    with self.cached_session():
      input1 = keras.layers.Input((timesteps, dim))
      layer = keras.layers.Bidirectional(
          rnn(units, return_state=True, return_sequences=True))
      state = layer(input1)[1:]

      # test passing invalid initial_state: passing a tensor
      input2 = keras.layers.Input((timesteps, dim))
      with self.assertRaises(ValueError):
        output = keras.layers.Bidirectional(
            rnn(units))(input2, initial_state=state[0])

      # test valid usage: passing a list
      output = keras.layers.Bidirectional(rnn(units))(input2,
                                                      initial_state=state)
      model = keras.models.Model([input1, input2], output)
      assert len(model.layers) == 4
      assert isinstance(model.layers[-1].input, list)
      inputs = [np.random.rand(samples, timesteps, dim),
                np.random.rand(samples, timesteps, dim)]
      model.predict(inputs)

  def test_Bidirectional_trainable(self):
    # test layers that need learning_phase to be set
    with self.cached_session():
      x = keras.layers.Input(shape=(3, 2))
      layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
      _ = layer(x)
      assert len(layer.trainable_weights) == 6
      layer.trainable = False
      assert not layer.trainable_weights
      layer.trainable = True
      assert len(layer.trainable_weights) == 6

  def test_Bidirectional_updates(self):
    if context.executing_eagerly():
      self.skipTest('layer.updates is only available in graph mode.')

    with self.cached_session():
      x = keras.layers.Input(shape=(3, 2))
      x_reachable_update = x * x
      layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
      _ = layer(x)
      assert not layer.updates
      assert not layer.get_updates_for(None)
      assert not layer.get_updates_for(x)
      layer.forward_layer.add_update(x_reachable_update, inputs=x)
      layer.forward_layer.add_update(1, inputs=None)
      layer.backward_layer.add_update(x_reachable_update, inputs=x)
      layer.backward_layer.add_update(1, inputs=None)
      assert len(layer.updates) == 4
      assert len(layer.get_updates_for(None)) == 2
      assert len(layer.get_updates_for(x)) == 2

  def test_Bidirectional_losses(self):
    with self.cached_session():
      x = keras.layers.Input(shape=(3, 2))
      x_reachable_loss = x * x
      layer = keras.layers.Bidirectional(
          keras.layers.SimpleRNN(
              3, kernel_regularizer='l1', bias_regularizer='l1'))
      _ = layer(x)
      assert len(layer.losses) == 4
      assert len(layer.get_losses_for(None)) == 4
      assert not layer.get_losses_for(x)

      # Create a random tensor that is not conditional on the inputs.
      with keras.backend.get_graph().as_default():
        const_tensor = constant_op.constant(1)

      layer.forward_layer.add_loss(x_reachable_loss, inputs=x)
      layer.forward_layer.add_loss(const_tensor, inputs=None)
      layer.backward_layer.add_loss(x_reachable_loss, inputs=x)
      layer.backward_layer.add_loss(const_tensor, inputs=None)
      assert len(layer.losses) == 8
      assert len(layer.get_losses_for(None)) == 6
      assert len(layer.get_losses_for(x)) == 2

  def test_Bidirectional_with_constants(self):
    with self.cached_session():
      # Test basic case.
      x = keras.Input((5, 5))
      c = keras.Input((3,))
      cell = _RNNCellWithConstants(32)
      custom_objects = {'_RNNCellWithConstants': _RNNCellWithConstants}
      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional(keras.layers.RNN(cell))
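      # The default merge_mode is 'concat', so outputs have 2 * 32 = 64 units.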
      y = layer(x, constants=c)
      model = keras.Model([x, c], y)
      model.compile(optimizer='rmsprop', loss='mse')
      model.train_on_batch(
          [np.zeros((6, 5, 5)), np.zeros((6, 3))],
          np.zeros((6, 64))
      )

      # Test basic case serialization.
      x_np = np.random.random((6, 5, 5))
      c_np = np.random.random((6, 3))
      y_np = model.predict([x_np, c_np])
      weights = model.get_weights()
      config = layer.get_config()

      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
      y = layer(x, constants=c)
      model = keras.Model([x, c], y)
      model.set_weights(weights)
      y_np_2 = model.predict([x_np, c_np])
      self.assertAllClose(y_np, y_np_2, atol=1e-4)

      # Test flat list inputs
      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
      y = layer([x, c])
      model = keras.Model([x, c], y)
      model.set_weights(weights)
      y_np_3 = model.predict([x_np, c_np])
      self.assertAllClose(y_np, y_np_3, atol=1e-4)

  def test_Bidirectional_with_constants_layer_passing_initial_state(self):
    with self.cached_session():
      # Test basic case.
      x = keras.Input((5, 5))
      c = keras.Input((3,))
      s_for = keras.Input((32,))
      s_bac = keras.Input((32,))
      cell = _RNNCellWithConstants(32)
      custom_objects = {'_RNNCellWithConstants': _RNNCellWithConstants}
      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional(keras.layers.RNN(cell))
      y = layer(x, initial_state=[s_for, s_bac], constants=c)
      model = keras.Model([x, s_for, s_bac, c], y)
      model.compile(optimizer='rmsprop', loss='mse')
      model.train_on_batch(
          [np.zeros((6, 5, 5)),
           np.zeros((6, 32)),
           np.zeros((6, 32)),
           np.zeros((6, 3))],
          np.zeros((6, 64))
      )

      # Test basic case serialization.
      x_np = np.random.random((6, 5, 5))
      s_fw_np = np.random.random((6, 32))
      s_bk_np = np.random.random((6, 32))
      c_np = np.random.random((6, 3))
      y_np = model.predict([x_np, s_fw_np, s_bk_np, c_np])
      weights = model.get_weights()
      config = layer.get_config()

      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
      y = layer(x, initial_state=[s_for, s_bac], constants=c)
      model = keras.Model([x, s_for, s_bac, c], y)
      model.set_weights(weights)
      y_np_2 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
      self.assertAllClose(y_np, y_np_2, atol=1e-4)

      # Verify that the initial state is actually used.
      y_np_2_different_s = model.predict(
          [x_np, s_fw_np + 10., s_bk_np + 10., c_np])
      assert np.mean(y_np - y_np_2_different_s) != 0

      # Test flat list inputs
      with keras.utils.CustomObjectScope(custom_objects):
        layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
      y = layer([x, s_for, s_bac, c])
      model = keras.Model([x, s_for, s_bac, c], y)
      model.set_weights(weights)
      y_np_3 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
      self.assertAllClose(y_np, y_np_3, atol=1e-4)

  def test_Bidirectional_last_output_with_masking(self):
    rnn = keras.layers.LSTM
    samples = 2
    dim = 5
    timesteps = 3
    units = 3
    merge_mode = 'concat'
    x = np.random.rand(samples, timesteps, dim)
    # Zero out the first record's timestep 2. The last output should match the
    # final state rather than being zeroed.
    x[0, 2] = 0

    with self.cached_session():
      inputs = keras.Input((timesteps, dim))
      masked_inputs = keras.layers.Masking()(inputs)
      wrapped = keras.layers.Bidirectional(
          rnn(units, return_state=True), merge_mode=merge_mode)
      outputs = _to_list(wrapped(masked_inputs, training=True))
      self.assertEqual(len(outputs), 5)
      self.assertEqual(outputs[0].get_shape().as_list(), [None, units * 2])

      model = keras.Model(inputs, outputs)
      y = _to_list(model.predict(x))
      self.assertEqual(len(y), 5)
      self.assertAllClose(y[0], np.concatenate([y[1], y[3]], axis=1))

  def test_Bidirectional_sequence_output_with_masking(self):
    rnn = keras.layers.LSTM
    samples = 2
    dim = 5
    timesteps = 3
    units = 3
    merge_mode = 'concat'
    x = np.random.rand(samples, timesteps, dim)
    # Zero out the first record's timestep 2 and expect the output at
    # timestep 2 to be zeros as well.
    x[0, 2] = 0

    with self.cached_session():
      inputs = keras.Input((timesteps, dim))
      masked_inputs = keras.layers.Masking()(inputs)
      wrapped = keras.layers.Bidirectional(
          rnn(units, return_sequences=True),
          merge_mode=merge_mode)
      outputs = _to_list(wrapped(masked_inputs, training=True))
      self.assertEqual(len(outputs), 1)
      self.assertEqual(outputs[0].get_shape().as_list(),
                       [None, timesteps, units * 2])

      model = keras.Model(inputs, outputs)
      y = _to_list(model.predict(x))
      self.assertEqual(len(y), 1)
      self.assertAllClose(y[0][0, 2], np.zeros(units * 2))


def _to_list(ls):
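  """Returns `ls` if it is already a list, otherwise wraps it in one."""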
  if isinstance(ls, list):
    return ls
  else:
    return [ls]


if __name__ == '__main__':
  test.main()