1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for sequence data preprocessing utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from math import ceil
22
23import numpy as np
24
25from tensorflow.python.keras.preprocessing import sequence as preprocessing_sequence
26from tensorflow.python.platform import test
27
28
class TestSequence(test.TestCase):
  """Covers the padding, sampling, skip-gram and time-series helpers."""

  @staticmethod
  def _windows(*rows):
    # Wraps every scalar of every row in its own list, so that
    # `_windows([0, 2], [1, 3])` equals `np.array([[[0], [2]], [[1], [3]]])`.
    return np.array([[[value] for value in row] for row in rows])

  def test_pad_sequences(self):
    # Scalar sequences of lengths 1, 2 and 3.
    ragged = [[1], [1, 2], [1, 2, 3]]
    scenarios = [
        # Padding side.
        ({'maxlen': 3, 'padding': 'pre'}, [[0, 0, 1], [0, 1, 2], [1, 2, 3]]),
        ({'maxlen': 3, 'padding': 'post'}, [[1, 0, 0], [1, 2, 0], [1, 2, 3]]),
        # Truncation side.
        ({'maxlen': 2, 'truncating': 'pre'}, [[0, 1], [1, 2], [2, 3]]),
        ({'maxlen': 2, 'truncating': 'post'}, [[0, 1], [1, 2], [1, 2]]),
        # Custom fill value.
        ({'maxlen': 3, 'value': 1}, [[1, 1, 1], [1, 1, 2], [1, 2, 3]]),
    ]
    for kwargs, expected in scenarios:
      padded = preprocessing_sequence.pad_sequences(ragged, **kwargs)
      self.assertAllClose(padded, expected)

  def test_pad_sequences_vector(self):
    # Same scenarios as the scalar test, but every timestep is a 2-vector.
    ragged = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]
    scenarios = [
        # Padding side.
        ({'maxlen': 3, 'padding': 'pre'},
         [[[0, 0], [0, 0], [1, 1]],
          [[0, 0], [2, 1], [2, 2]],
          [[3, 1], [3, 2], [3, 3]]]),
        ({'maxlen': 3, 'padding': 'post'},
         [[[1, 1], [0, 0], [0, 0]],
          [[2, 1], [2, 2], [0, 0]],
          [[3, 1], [3, 2], [3, 3]]]),
        # Truncation side.
        ({'maxlen': 2, 'truncating': 'pre'},
         [[[0, 0], [1, 1]],
          [[2, 1], [2, 2]],
          [[3, 2], [3, 3]]]),
        ({'maxlen': 2, 'truncating': 'post'},
         [[[0, 0], [1, 1]],
          [[2, 1], [2, 2]],
          [[3, 1], [3, 2]]]),
        # Custom fill value.
        ({'maxlen': 3, 'value': 1},
         [[[1, 1], [1, 1], [1, 1]],
          [[1, 1], [2, 1], [2, 2]],
          [[3, 1], [3, 2], [3, 3]]]),
    ]
    for kwargs, expected in scenarios:
      padded = preprocessing_sequence.pad_sequences(ragged, **kwargs)
      self.assertAllClose(padded, expected)

  def test_make_sampling_table(self):
    # Only a loose (10% relative) match against the reference values is
    # required.
    reference = np.asarray([0.00315225, 0.00315225, 0.00547597])
    table = preprocessing_sequence.make_sampling_table(3)
    self.assertAllClose(table, reference, rtol=.1)

  def test_skipgrams(self):
    # Default window size and label format: every word of every couple must
    # belong to the three-word vocabulary.
    vocabulary = [0, 1, 2]
    pairs, _ = preprocessing_sequence.skipgrams(
        np.arange(3), vocabulary_size=3)
    for pair in pairs:
      self.assertIn(pair[0], vocabulary)
      self.assertIn(pair[1], vocabulary)

    # Explicit window size with categorical labels: each label is expected
    # to have two entries.
    pairs, categorical_labels = preprocessing_sequence.skipgrams(
        np.arange(5), vocabulary_size=5, window_size=1, categorical=True)
    for pair in pairs:
      self.assertLessEqual(pair[0] - pair[1], 3)
    for label in categorical_labels:
      self.assertEqual(len(label), 2)

  def test_remove_long_seq(self):
    # With maxlen=3 the length-3 sequence and its label are expected to be
    # dropped while the shorter ones are kept.
    sequences = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]
    labels = ['a', 'b', ['c', 'd']]

    kept_sequences, kept_labels = preprocessing_sequence._remove_long_seq(
        maxlen=3, seq=sequences, label=labels)

    self.assertEqual(kept_sequences, [[[1, 1]], [[2, 1], [2, 2]]])
    self.assertEqual(kept_labels, ['a', 'b'])

  def test_TimeseriesGenerator(self):
    make_gen = preprocessing_sequence.TimeseriesGenerator
    # Column vectors 0..49; sample i and target i both hold the value i.
    series = np.arange(50).reshape(-1, 1)
    labels = np.arange(50).reshape(-1, 1)

    # Baseline: every second sample, two windows per batch.
    gen = make_gen(series, labels, length=10, sampling_rate=2, batch_size=2)
    self.assertEqual(len(gen), 20)
    self.assertAllClose(
        gen[0][0], self._windows([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]))
    self.assertAllClose(gen[0][1], np.array([[10], [11]]))
    self.assertAllClose(
        gen[1][0], self._windows([2, 4, 6, 8, 10], [3, 5, 7, 9, 11]))
    self.assertAllClose(gen[1][1], np.array([[12], [13]]))

    # reverse=True: the timesteps inside each window come out back to front.
    gen = make_gen(
        series, labels, length=10, sampling_rate=2, reverse=True, batch_size=2)
    self.assertEqual(len(gen), 20)
    self.assertAllClose(
        gen[0][0], self._windows([8, 6, 4, 2, 0], [9, 7, 5, 3, 1]))
    self.assertAllClose(gen[0][1], np.array([[10], [11]]))

    # shuffle=True: the drawn window is unknown in advance, so the expected
    # samples are derived from whichever target came back.
    gen = make_gen(
        series, labels, length=10, sampling_rate=2, shuffle=True, batch_size=1)
    shuffled = gen[0]
    anchor = shuffled[1][0][0]
    self.assertAllClose(
        shuffled[0],
        self._windows([anchor - offset for offset in (10, 8, 6, 4, 2)]))
    self.assertAllClose(shuffled[1], np.array([[anchor]]))

    # stride=2: consecutive windows start two samples apart.
    gen = make_gen(
        series, labels, length=10, sampling_rate=2, stride=2, batch_size=2)
    self.assertEqual(len(gen), 10)
    self.assertAllClose(
        gen[1][0], self._windows([4, 6, 8, 10, 12], [6, 8, 10, 12, 14]))
    self.assertAllClose(gen[1][1], np.array([[14], [16]]))

    # Generation restricted to the index range [10, 30].
    bounded = {
        'length': 10,
        'sampling_rate': 2,
        'start_index': 10,
        'end_index': 30,
        'batch_size': 2,
    }
    gen = make_gen(series, labels, **bounded)
    self.assertEqual(len(gen), 6)
    self.assertAllClose(
        gen[0][0],
        self._windows([10, 12, 14, 16, 18], [11, 13, 15, 17, 19]))
    self.assertAllClose(gen[0][1], np.array([[20], [21]]))

    # Same bounded setup with multi-dimensional random samples and targets.
    series = np.array(
        [np.random.random_sample((1, 2, 3, 4)) for _ in range(50)])
    labels = np.array([np.random.random_sample((3, 2, 1)) for _ in range(50)])
    gen = make_gen(series, labels, **bounded)
    self.assertEqual(len(gen), 6)
    self.assertAllClose(
        gen[0][0], np.array([series[10:19:2], series[11:20:2]]))
    self.assertAllClose(gen[0][1], labels[20:22])

    # A window as long as the whole series must be rejected.
    with self.assertRaises(ValueError) as raised:
      make_gen(series, labels, length=50)
    self.assertIn('`start_index+length=50 > end_index=49` is disallowed',
                  str(raised.exception))

  def test_TimeSeriesGenerator_doesnt_miss_any_sample(self):
    make_gen = preprocessing_sequence.TimeseriesGenerator

    # Part 1: with batch_size=1 every index in range(length, 10) must show up
    # exactly once, in order, as a target.
    samples = np.arange(10).reshape(-1, 1)
    for length in range(3, 10):
      gen = make_gen(samples, samples, length=length, batch_size=1)
      n_batches = len(gen)
      self.assertEqual(max(0, len(samples) - length), n_batches)
      if n_batches <= 0:
        continue

      targets = np.concatenate(
          [gen[idx][1] for idx in range(n_batches)], axis=0)
      self.assertAllClose(targets, np.arange(length, 10).reshape(-1, 1))

    # Part 2: sequence and batch counts for assorted configurations.
    n_samples = 23
    samples = np.arange(n_samples).reshape(-1, 1)
    configs = (
        # (stride, length, batch_size, shuffle)
        (1, 3, 6, False),
        (1, 3, 6, True),
        (5, 4, 6, True),
        (7, 3, 5, False),
        (3, 1, 6, False),
        (5, 3, 6, False),
        (3, 7, 6, False),
    )
    for stride, length, batch_size, shuffle in configs:
      gen = make_gen(
          samples,
          samples,
          length=length,
          sampling_rate=1,
          stride=stride,
          start_index=0,
          end_index=None,
          shuffle=shuffle,
          reverse=False,
          batch_size=batch_size)

      span = n_samples - length
      if shuffle:
        # All batches have the same size when shuffling is enabled.
        want_sequences = ceil(span / float(batch_size * stride)) * batch_size
      else:
        # The last batch is smaller when `span / stride` is not a multiple
        # of `batch_size`.
        want_sequences = ceil(span / float(stride))
      want_batches = ceil(want_sequences / float(batch_size))

      target_batches = [gen[idx][1] for idx in range(len(gen))]
      got_sequences = sum(len(batch) for batch in target_batches)

      self.assertEqual(want_sequences, got_sequences)
      self.assertEqual(want_batches, len(target_batches))
244
if __name__ == '__main__':
  # Hand control to the imported `test` module's runner only when this file
  # is executed as a script (not when it is imported).
  test.main()
247